Spaces: Running on Zero

Commit 37a9836 · Parent(s): 6e4576a
add code

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +179 -0
- LICENSE +21 -0
- README.md +92 -14
- app.py +191 -0
- config.py +12 -0
- core/__init__.py +0 -0
- core/bark/__init__.py +5 -0
- core/bark/constants.py +18 -0
- core/bark/custom_context.py +79 -0
- core/bark/encodec.py +63 -0
- core/bark/generate_audio.py +117 -0
- core/bark/generate_audio_semantic_dataset.py +122 -0
- core/bark/generate_coarse.py +385 -0
- core/bark/generate_fine.py +210 -0
- core/bark/generate_semantic.py +361 -0
- core/bark/voice_clone.py +104 -0
- core/data_model/__init__.py +1 -0
- core/data_model/bark.py +337 -0
- core/memory/__init__.py +5 -0
- core/memory/common.py +187 -0
- core/memory/model_manager.py +289 -0
- core/memory/models.py +169 -0
- core/model/__init__.py +1 -0
- core/model/bark.py +425 -0
- core/model/hubert.py +237 -0
- core/trainer/__init__.py +1 -0
- core/trainer/custom_hubert_trainer.py +555 -0
- core/utils/__init__.py +7 -0
- core/utils/audio.py +104 -0
- core/utils/huggingface.py +169 -0
- core/utils/read_write_files.py +46 -0
- core/utils/text.py +13 -0
- event_handlers.py +436 -0
- generate_audio_semantic_dataset.py +155 -0
- prompts/de_speaker_0.npz +0 -0
- prompts/de_speaker_1.npz +0 -0
- prompts/de_speaker_2.npz +0 -0
- prompts/de_speaker_3.npz +0 -0
- prompts/de_speaker_4.npz +0 -0
- prompts/de_speaker_5.npz +0 -0
- prompts/de_speaker_6.npz +0 -0
- prompts/de_speaker_7.npz +0 -0
- prompts/de_speaker_8.npz +0 -0
- prompts/de_speaker_9.npz +0 -0
- prompts/en_speaker_0.npz +0 -0
- prompts/en_speaker_1.npz +0 -0
- prompts/en_speaker_2.npz +0 -0
- prompts/en_speaker_3.npz +0 -0
- prompts/en_speaker_4.npz +0 -0
- prompts/en_speaker_5.npz +0 -0
.gitignore
ADDED
@@ -0,0 +1,179 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc
bark_prompts/
generated_audio/

models/
.DS_Store
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Hao Huynh Nhat

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,14 +1,92 @@
(The previous README consisted of 14 blank lines, now replaced with the content below.)

# Generate Audio from text and clone voice with BARK

You can generate audio from text with a natural-sounding voice and clone any voice (the clone is not perfect).

The code has been tested on Python 3.12 and may also work on other versions.

Example generated audio files are in the /assets/audio folder.

## Features

- **Text-to-Audio Generation:** Generate speech from text using the BARK model (supports 'small' and 'large' variants).
- **Parameter Control:** Adjust semantic, coarse, and fine temperature settings for generation diversity. Set a generation seed for reproducibility.
- **Device Selection:** Run inference on available devices (CPU, CUDA, MPS).
- **Standard Voice Prompts:** Utilize built-in BARK voice prompts (`.npz` files) located in the `bark_prompts` directory.
- **Custom Voice Prompt Creation (Voice Cloning):**
  - Upload your own audio file (.wav, .mp3).
  - Generate a BARK-compatible semantic prompt (`.npz` file) using a custom-trained HuBERT model.
  - The generated prompt appears in the "Select Voice Prompt" dropdown for immediate use.
- **Audio Management:** View, play, and delete generated audio files directly within the interface.
- **Training Scripts:** Includes scripts to generate the necessary dataset (`generate_audio_semantic_dataset.py`) and train the custom HuBERT model (`train_hubert.py`).

## Custom Voice Cloning Model

The core of the custom voice prompt generation relies on a fine-tuned HuBERT model.

- **Model:** `sleeper371/hubert-for-bark-semantic` on Hugging Face ([Link](https://huggingface.co/sleeper371/hubert-for-bark-semantic))
- **Architecture:** This model uses a HuBERT base feature extractor followed by a Transformer decoder head.
- **Training:** It was trained on over 4700 sentence pairs, mapping audio waveforms to the semantic tokens generated by BARK's semantic model. The training used a cross-entropy loss objective.
- **Dataset:** The training dataset is available at `sleeper371/bark-wave-semantic` on Hugging Face ([Link](https://huggingface.co/datasets/sleeper371/bark-wave-semantic)).
- **Comparison:** This approach is inspired by projects like [gitmylo/bark-data-gen](https://github.com/gitmylo/bark-data-gen), but differs in the head architecture (that project used an LSTM head, while this one uses a Transformer decoder head).

## Setup and Installation

Follow these steps to set up the environment and run the application.

1. **Clone the Repository:**

2. **Create a Virtual Environment:**
   It's highly recommended to use a virtual environment to manage dependencies.

   ```bash
   # For Linux/macOS
   python3 -m venv venv
   source venv/bin/activate

   # For Windows
   python -m venv venv
   .\venv\Scripts\activate
   ```

3. **Install Requirements:**
   Make sure you have a `requirements.txt` file in the repository root containing all necessary packages (e.g., `gradio`, `torch`, `transformers`, `soundfile`, etc.).

   ```bash
   pip install -r requirements.txt
   ```

## Running the Application

Once the setup is complete, run the Gradio application:

```bash
python app.py
```

This will launch the Gradio interface, typically accessible at http://127.0.0.1:7860 in your web browser. The console output will provide the exact URL.

## Training Your Own Custom HuBERT Model

If you want to train your own HuBERT model for voice cloning:

1. Generate the dataset:

   - Use the generate_audio_semantic_dataset.py script.

2. Train the model:

   - Use the train_hubert.py script.
   - This script takes the generated dataset (audio paths and semantic token paths) to fine-tune a HuBERT model with a Transformer decoder head.
   - Configure training parameters (batch size, learning rate, epochs, output directory) within the script or via command-line arguments (if implemented).

## License

MIT

## Acknowledgements

- Suno AI, who trained the BARK models
- gitmylo, who inspired me to use HuBERT to predict semantic tokens from audio
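Note (not part of the repository files): the README above describes the voice-cloning model as a HuBERT base feature extractor followed by a Transformer decoder head, trained with cross-entropy on waveform-to-semantic-token pairs. The sketch below is only a rough, hypothetical illustration of that shape; the class name, hyperparameters, and checkpoint are assumptions, and the actual implementation lives in `core/model/hubert.py` and `core/trainer/custom_hubert_trainer.py` (not shown in this view).

```python
# Hedged sketch: HuBERT encoder + Transformer decoder head predicting BARK
# semantic tokens (vocab size 10_003). All hyperparameters are guesses.
import torch
import torch.nn as nn
from transformers import HubertModel


class HubertForBarkSemanticSketch(nn.Module):  # hypothetical name
    def __init__(self, vocab_size: int = 10_003, d_model: int = 768,
                 n_layers: int = 4, n_heads: int = 8):
        super().__init__()
        # HuBERT base acts as the audio feature extractor (hidden size 768)
        self.encoder = HubertModel.from_pretrained("facebook/hubert-base-ls960")
        self.token_emb = nn.Embedding(vocab_size, d_model)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=n_heads, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, waveform: torch.Tensor, target_tokens: torch.Tensor) -> torch.Tensor:
        # waveform: (batch, samples) at 16 kHz; target_tokens: (batch, T_sem)
        memory = self.encoder(waveform).last_hidden_state      # (batch, T_enc, 768)
        tgt = self.token_emb(target_tokens)                    # (batch, T_sem, 768)
        causal_mask = nn.Transformer.generate_square_subsequent_mask(
            tgt.size(1)
        ).to(tgt.device)
        hidden = self.decoder(tgt, memory, tgt_mask=causal_mask)
        return self.lm_head(hidden)                            # logits for cross-entropy loss
```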
app.py
ADDED
@@ -0,0 +1,191 @@
import gradio as gr
from config import *
from event_handlers import *


# --- Gradio UI Definition ---
# theme = gr.themes.Default(primary_hue=gr.themes.colors.blue).set()
theme = gr.themes.Ocean(primary_hue=gr.themes.colors.blue).set()

with gr.Blocks(
    theme=theme,
    title="grAudio",
    css=".gradio-container { max-width: 95% !important; }",
) as app:

    # --- Global State ---
    initial_audio_list = load_existing_audio()
    audio_list_state = gr.State(value=initial_audio_list)
    newly_generated_state = gr.State([])
    # State to store the index of the selected row in the DataFrame
    selected_index_state = gr.State(-1)  # -1 means nothing selected

    # --- UI Layout ---
    gr.Markdown("# Generate Audio from text")
    with gr.Row(equal_height=False):
        # --- Column 1: Configuration (Left) ---
        with gr.Column(scale=2, min_width=350):
            gr.Markdown("### Generation Configuration")
            with gr.Accordion("Batch size & Temperatures", open=True):
                batch_size_number = gr.Number(
                    value=1,
                    label="Seed",
                    minimum=0,
                    step=1,
                    scale=1,
                )
                semantic_temp_slider = gr.Slider(
                    0.1, 1.0, value=0.7, step=0.1, label="Semantic Temp"
                )
                coarse_temp_slider = gr.Slider(
                    0.1, 1.0, value=0.7, step=0.1, label="Coarse Temp"
                )
                fine_temp_slider = gr.Slider(
                    0.1, 1.0, value=0.7, step=0.1, label="Fine Temp"
                )
            with gr.Accordion("Model, Devices", open=True):
                model_type_dropdown = gr.Dropdown(
                    choices=["small", "large"], value="small", label="Model Type"
                )

                available_devices, best_device = get_available_torch_devices()
                device_dropdown = gr.Dropdown(
                    choices=available_devices, value=best_device, label="Device"
                )
            with gr.Accordion("Voice Prompt", open=True):
                prompt_dropdown = gr.Dropdown(
                    choices=get_available_prompts(),
                    label="Select Voice Prompt",
                    info="Optional",
                    multiselect=False,
                    allow_custom_value=False,
                )
                refresh_prompts_btn = gr.Button(
                    "Refresh Prompts", variant="secondary", size="sm"
                )
            with gr.Accordion("Create New Voice Prompt", open=False):
                prompt_audio_upload = gr.File(
                    value=None,
                    file_count="single",
                    label="Upload Audio (.wav, .mp3)",
                    file_types=["audio"],
                    type="filepath",
                )
                create_prompt_btn = gr.Button("Create Prompt", variant="secondary")

        # --- Column 2: Text Input & Generate Button (Middle) ---
        with gr.Column(scale=4, min_width=600):
            gr.Markdown("### Text Input")
            text_input_block = gr.Textbox(
                lines=30,
                placeholder="If your text includes multiple long sentences, select a voice prompt to have consistent speech.\nDo not use long sentences; split them into multiple sentences, each shorter than 15 seconds.",
                label="Text Prompts",
            )
            generate_btn = gr.Button("Generate", variant="primary")
        # --- Column 3: Generated Audio Display (Right) - SIMPLIFIED ---
        with gr.Column(scale=2, min_width=250):
            gr.Markdown("### Generated Audio")
            # DataFrame to display the list
            audio_dataframe = gr.DataFrame(
                headers=["File", "Prompt", "Duration (s)"],
                datatype=["str", "str", "str"],
                interactive=True,  # Allow row selection
                row_count=(10, "dynamic"),  # Show ~10 rows, scroll if more
                col_count=(3, "fixed"),
                # value=format_audio_list_for_dataframe(initial_audio_list)  # Set initial value via app.load
            )
            # Single audio player for the selected item
            selected_audio_player = gr.Audio(
                label="Selected Audio",
                type="filepath",
                interactive=False,  # Only for playback
            )
            # Single delete button
            delete_selected_btn = gr.Button("Delete Selected Audio", variant="stop")

    # --- Event Handling ---

    # 1. Refresh Prompts Button
    refresh_prompts_btn.click(
        fn=update_available_prompts, inputs=None, outputs=[prompt_dropdown]
    )

    # 2. Create Prompt Button
    create_prompt_btn.click(
        fn=create_audio_prompt,
        inputs=[prompt_audio_upload, device_dropdown],
        outputs=[prompt_dropdown],
    )

    # 3. Generate Button -> Calls backend -> Outputs to temporary state
    generate_btn.click(
        fn=generate_batch_audio,
        inputs=[
            text_input_block,
            semantic_temp_slider,
            coarse_temp_slider,
            fine_temp_slider,
            batch_size_number,
            model_type_dropdown,
            device_dropdown,
            prompt_dropdown,
        ],
        outputs=[newly_generated_state],
    )

    # 4. Temporary State Change -> Updates the main audio list state
    newly_generated_state.change(
        fn=update_audio_list,
        inputs=[newly_generated_state, audio_list_state],
        outputs=[audio_list_state],
        show_progress="hidden",
    )

    # 5. Main Audio List State Change -> Updates the DataFrame display
    # Also clears selection when the list updates.
    audio_list_state.change(
        fn=format_audio_list_for_dataframe,
        inputs=[audio_list_state],
        outputs=[audio_dataframe],
        show_progress="hidden",
    ).then(  # Chain: after updating dataframe, clear selection player and index
        fn=lambda: (None, -1),  # Function returning values to clear outputs
        inputs=None,
        outputs=[selected_audio_player, selected_index_state],
        show_progress="hidden",
        queue=False,
    )

    # 6. DataFrame Row Selection -> Updates the selected index and audio player
    audio_dataframe.select(
        fn=handle_row_selection,
        inputs=[audio_list_state],  # Pass the full list state to find the filepath
        outputs=[
            selected_audio_player,
            selected_index_state,
        ],
        show_progress="hidden",
    )

    # 7. Delete Selected Button Click -> Calls delete handler
    delete_selected_btn.click(
        fn=handle_delete_selected,
        inputs=[selected_index_state, audio_list_state],  # Pass index and list
        outputs=[
            audio_list_state,  # Update the main list state
            selected_index_state,  # Clear the selected index
            selected_audio_player,  # Clear the audio player
        ],
        show_progress="hidden",
    )

    # 8. Initial Load: Populate the DataFrame
    app.load(
        fn=format_audio_list_for_dataframe,
        inputs=[audio_list_state],  # Use the initial state value
        outputs=[audio_dataframe],  # Render initial data into the DataFrame
    )


if __name__ == "__main__":
    app.launch(debug=True, share=False)
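Note (not part of the repository files): event_handlers.py appears in the file list above but its body is not included in this 50-file view. As a hedged sketch of how the wiring in app.py maps to handler signatures (one positional argument per component in `inputs`, one return value per component in `outputs`), something along these lines would satisfy the interface; the real implementations may differ.

```python
# Hypothetical shapes for two of the handlers wired in app.py above.
# Each handler receives one value per component in `inputs` and returns
# one value per component in `outputs`.
from typing import List, Optional, Tuple


def update_audio_list(new_items: List[dict], current_list: List[dict]) -> List[dict]:
    # Append freshly generated entries to the running audio list state.
    return current_list + new_items


def handle_delete_selected(
    selected_index: int, audio_list: List[dict]
) -> Tuple[List[dict], int, Optional[str]]:
    # Drop the selected row, reset the selection index, and clear the player.
    if 0 <= selected_index < len(audio_list):
        audio_list = audio_list[:selected_index] + audio_list[selected_index + 1:]
    return audio_list, -1, None
```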
config.py
ADDED
@@ -0,0 +1,12 @@
import os

# --- Configuration ---
PROMPT_DIR = "./prompts"
GENERATED_AUDIO_DIR = "./generated_audio"
os.makedirs(PROMPT_DIR, exist_ok=True)
os.makedirs(GENERATED_AUDIO_DIR, exist_ok=True)

# Constants for audio generation
DEFAULT_AUDIO_SAMPLE_RATE = 24000
DEFAULT_DURATION = 3
DEFAULT_FREQ = 440
core/__init__.py
ADDED
File without changes
core/bark/__init__.py
ADDED
@@ -0,0 +1,5 @@
from core.bark.generate_audio import *

from core.bark.encodec import *

from core.bark.voice_clone import *
core/bark/constants.py
ADDED
@@ -0,0 +1,18 @@
# original BARK semantic vocab size
SEMANTIC_VOCAB_SIZE = 10_000
# HuBERT model output vocab size
HUBERT_OUTPUT_VOCAB_SIZE = 10_003
CODEBOOK_SIZE = 1024
N_COARSE_CODEBOOKS = 2
COARSE_RATE_HZ = 75
COARSE_SEMANTIC_PAD_TOKEN = 12_048
COARSE_INFER_TOKEN = 12_050

# for the BERT model to get semantic tokens from raw texts
TEXT_ENCODING_OFFSET = 10_048
SEMANTIC_PAD_TOKEN = 10_000
TEXT_PAD_TOKEN = 129_595
SEMANTIC_INFER_TOKEN = 129_599
SEMANTIC_RATE_HZ = 49.9

N_FINE_CODEBOOKS = 8
core/bark/custom_context.py
ADDED
@@ -0,0 +1,79 @@
import contextlib
import torch
import funcy


"""
Custom context managers for PyTorch inference operations.

This module provides context managers for controlling:
- CUDA benchmarking settings
- Inference mode and gradient calculation
- Automatic mixed precision (AMP) casting

The main context manager `inference_mode()` combines all these settings
for optimal inference performance.
"""


class InferenceContext:
    """
    Context manager for controlling CUDA benchmarking settings.

    Args:
        benchmark (bool): Whether to enable cudnn benchmarking. Defaults to False
            since input lengths may vary in inference scenarios.

    This context manager saves and restores the original cudnn.benchmark setting
    when entering/exiting the context.
    """

    def __init__(self, benchmark=False):
        # we can't expect inputs to be the same length, so disable benchmarking by default
        self._chosen_cudnn_benchmark = benchmark
        self._cudnn_benchmark = None

    def __enter__(self):
        self._cudnn_benchmark = torch.backends.cudnn.benchmark
        torch.backends.cudnn.benchmark = self._chosen_cudnn_benchmark

    def __exit__(self, exc_type, exc_value, exc_traceback):
        torch.backends.cudnn.benchmark = self._cudnn_benchmark


if (
    torch.cuda.is_available()
    and hasattr(torch.cuda, "amp")
    and hasattr(torch.cuda.amp, "autocast")
    and hasattr(torch.cuda, "is_bf16_supported")
    and torch.cuda.is_bf16_supported()
):
    autocast = funcy.partial(
        torch.amp.autocast, dtype=torch.bfloat16, device_type="cuda"
    )
    """Context manager for automatic mixed precision (AMP) using bfloat16 where supported."""
else:

    @contextlib.contextmanager
    def autocast():
        """No-op autocast context manager when bfloat16 is not supported."""
        yield


@contextlib.contextmanager
def inference_mode():
    """
    Combined context manager for optimal inference performance.

    Combines:
    - CUDA benchmarking control
    - PyTorch inference mode
    - Disabled gradient calculation
    - Automatic mixed precision casting (where supported)

    Usage:
        with inference_mode():
            # inference operations here
    """
    with InferenceContext(), torch.inference_mode(), torch.no_grad(), autocast():
        yield
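Note (not part of the repository files): a short usage sketch of the `inference_mode()` context above, wrapping an arbitrary forward pass; the linear layer is only a stand-in for one of the BARK GPT models.

```python
# Minimal usage sketch for core.bark.custom_context.inference_mode.
import torch
from core.bark.custom_context import inference_mode

model = torch.nn.Linear(16, 4)   # stand-in for a real model
x = torch.randn(2, 16)

with inference_mode():
    # Runs with cudnn benchmarking disabled, gradients off, inference mode on,
    # and bfloat16 autocast when a bf16-capable CUDA device is available.
    y = model(x)

print(y.shape)  # torch.Size([2, 4])
```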
core/bark/encodec.py
ADDED
@@ -0,0 +1,63 @@
import torch
import numpy as np

from encodec import EncodecModel
from encodec.utils import convert_audio
from core.memory import model_manager, ModelEnum, env
from core.bark.custom_context import inference_mode


def encodec_decode_fine_tokens_to_audio(fine_tokens: torch.Tensor) -> np.ndarray:
    """
    Decode the given fine tokens using EnCodec's decoder.

    Expects fine_tokens of shape [codebook_size, timestep], concretely [8, 75 * duration_in_sec].

    Returns:
        np.ndarray of shape (B, C, T), with C = 1 for mono audio.
    """
    model_info = ModelEnum.ENCODEC24k.value

    model_wrapper = model_manager.get_model(model_info)
    model: EncodecModel = model_wrapper.model

    device = next(model.parameters()).device

    input_tensor = fine_tokens.transpose(0, 1).to(device)

    emb = model.quantizer.decode(input_tensor)

    output: torch.Tensor = model.decoder(emb)
    audio_arr = output.detach().cpu().numpy()

    del input_tensor, emb, output

    return audio_arr


def encodec_encode_audio(
    audio_sample: torch.Tensor, audio_sample_rate: int
) -> torch.Tensor:
    """
    Encode the given audio sample using the EnCodec model.

    audio_sample expected shape: (channels, samples).

    Returns codes as a tensor of shape [n_q, T], where n_q is typically 8 and T is the
    compressed time step dimension (75 steps per second for the 24 kHz model).
    """
    model_wrapper = model_manager.get_model(ModelEnum.ENCODEC24k.value)
    model: EncodecModel = model_wrapper.model

    device = next(model.parameters()).device

    wav = convert_audio(
        audio_sample, audio_sample_rate, model.sample_rate, model.channels
    )
    wav = wav.unsqueeze(0).float().to(device)

    # Extract discrete codes from EnCodec
    with inference_mode():
        encoded_frames = model.encode(wav)

    codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)  # [B, n_q, T]

    return codes[0, :, :]
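Note (not part of the repository files): the shape contract that encodec_encode_audio and encodec_decode_fine_tokens_to_audio rely on can be reproduced with the encodec library directly. This sketch bypasses the project's model_manager and loads the 24 kHz model itself; the 6 kbps bandwidth (8 codebooks) is an assumption.

```python
# Sketch of the EnCodec shape contract used above, without the model_manager.
import torch
from encodec import EncodecModel

model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)  # 6 kbps -> 8 codebooks at 75 Hz (assumed setting)

wav = torch.randn(1, 1, 24000)  # (batch, channels, samples): one second of audio
with torch.no_grad():
    encoded_frames = model.encode(wav)
    codes = torch.cat([frame[0] for frame in encoded_frames], dim=-1)
    print(codes.shape)  # torch.Size([1, 8, 75]) -> [B, n_q, T], 75 steps per second

    # Decoding goes back through the quantizer and decoder, mirroring
    # encodec_decode_fine_tokens_to_audio (quantizer.decode expects [n_q, B, T]).
    emb = model.quantizer.decode(codes.transpose(0, 1))
    audio = model.decoder(emb)
print(audio.shape)  # roughly torch.Size([1, 1, 24000])
```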
core/bark/generate_audio.py
ADDED
@@ -0,0 +1,117 @@
import sys
import logging
from typing_extensions import Union, List
import numpy as np
import torch
from dataclasses import asdict

from core.bark.generate_semantic import generate_semantic_tokens_from_text
from core.bark.generate_coarse import generate_coarse_tokens_from_semantic
from core.bark.generate_fine import generate_fine_tokens_from_coarse


from core.data_model.bark import BarkPrompt, BarkGenerationConfig
from core.bark.encodec import encodec_decode_fine_tokens_to_audio
from core.bark.constants import SEMANTIC_PAD_TOKEN, SEMANTIC_RATE_HZ

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)


def generate_audio(
    texts: List[str],
    prompt: Union[BarkPrompt, None] = None,
    generation_config: BarkGenerationConfig = None,
    silent: bool = False,
) -> List[np.ndarray]:
    """
    Generate audio from text with an optional audio prompt.
    Args:
        texts (List[str]): Input texts to generate audio for. Must be non-empty.
        prompt (Union[BarkPrompt, None]): optional BarkPrompt (loaded from a .npz file) used as the audio prompt
        generation_config (BarkGenerationConfig): configuration used to generate audio
        silent (bool): if True, suppress progress bars

    """
    if prompt is not None:
        semantic_prompt = prompt.semantic_prompt if prompt is not None else None
        # if len(semantic_prompt.shape) == 2:
        #     semantic_prompt = semantic_prompt[0, :]
        assert (
            len(semantic_prompt.shape) == 1
        ), "expecting semantic prompt as a 1D array"
    else:
        semantic_prompt = None

    if generation_config is None:
        logger.info("using BARK default generation config")
        generation_config = BarkGenerationConfig()

    semantic_tokens = generate_semantic_tokens_from_text(
        texts,
        semantic_prompt,
        **asdict(generation_config),
        silent=silent,
    )

    # because we generate audio in batches, all audios in one batch have the length
    # of the longest audio. We need to remove the random section of each shorter audio
    # after it has ended.

    # coarse token generation
    coarse_tokens = generate_coarse_tokens_from_semantic(
        semantic_tokens, prompt, **asdict(generation_config), silent=silent
    )

    # fine token generation
    fine_tokens = generate_fine_tokens_from_coarse(
        coarse_tokens=coarse_tokens,
        history_prompt=prompt,
        temperature=generation_config.generate_fine_temperature,
        use_small_model=generation_config.use_small_model,
        silent=silent,
    )

    # decoding the codes
    audio_wave = encodec_decode_fine_tokens_to_audio(fine_tokens)
    assert (
        len(audio_wave.shape) == 3
    ), f"expecting audio tensor of shape (B, C, T), received {audio_wave.shape}"

    audio_wave = audio_wave.squeeze(1)  # squeeze the channel dimension
    res = remove_padded_segment_from_audio(audio_wave, semantic_tokens.cpu().numpy())
    return res


def remove_padded_segment_from_audio(
    audio_wave: np.ndarray, semantic_tokens: np.ndarray, audio_sample_rate: int = 24000
) -> List[np.ndarray]:
    # Because the semantic token tensor's time step dimension matches the longest audio
    # in the batch, all shorter audios would contain random sound after their end.
    # We cut each shorter audio at the position where its semantic sequence hits the
    # pad token to avoid random sound in the generated results.
    # SEMANTIC_PAD_TOKEN is also the end-of-sentence token.
    # This function assumes audio_wave has shape (batch, T).
    assert (
        len(audio_wave.shape) == 2
    ), f"expecting ndarray of shape (B, T), received {audio_wave.shape}"
    mask = semantic_tokens == SEMANTIC_PAD_TOKEN
    semantic_eos_indices = np.argmax(mask.astype(np.int32), axis=1)  # Shape [batch]
    wave_eos_indices: np.ndarray = semantic_eos_indices * (
        audio_sample_rate / SEMANTIC_RATE_HZ
    )
    wave_eos_indices = wave_eos_indices.astype(np.int32)
    res = []
    for wave_index in range(audio_wave.shape[0]):
        if wave_eos_indices[wave_index] == 0:
            # zero means this audio is the longest one in the batch and there is no need to cut the padded segment
            res.append(audio_wave[wave_index])
            continue
        start_padding_index = wave_eos_indices[wave_index]
        res.append(audio_wave[wave_index, :start_padding_index])

    return res
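Note (not part of the repository files): the index arithmetic in remove_padded_segment_from_audio maps a position in the semantic-token sequence to a position in the waveform via the ratio of the sample rate to the semantic token rate. A small worked example with made-up token values:

```python
# Worked example of the trimming math in remove_padded_segment_from_audio.
import numpy as np

SEMANTIC_PAD_TOKEN = 10_000
SEMANTIC_RATE_HZ = 49.9
sample_rate = 24_000

# Two sequences in a batch; the second hits the pad/EOS token at step 100.
semantic_tokens = np.full((2, 250), 42)
semantic_tokens[1, 100:] = SEMANTIC_PAD_TOKEN

mask = semantic_tokens == SEMANTIC_PAD_TOKEN
eos_idx = np.argmax(mask.astype(np.int32), axis=1)        # [0, 100]; 0 means "no padding found"
wave_eos = (eos_idx * (sample_rate / SEMANTIC_RATE_HZ)).astype(np.int32)
print(wave_eos)  # [0, 48096] -> the second waveform is cut at ~2.0 seconds
```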
core/bark/generate_audio_semantic_dataset.py
ADDED
@@ -0,0 +1,122 @@
from typing import List
import numpy as np
from tqdm import tqdm
from dataclasses import asdict
from core.bark.generate_semantic import generate_semantic_tokens_from_text
from core.bark.generate_coarse import generate_coarse_tokens_from_semantic
from core.bark.generate_fine import generate_fine_tokens_from_coarse
from core.bark.encodec import encodec_decode_fine_tokens_to_audio
from core.bark.generate_audio import remove_padded_segment_from_audio
from core.data_model import WavSemantic, WavSemanticDataset, BarkGenerationConfig
from core.bark.constants import SEMANTIC_PAD_TOKEN


def generate_wav_semantic_dataset(
    text_file_path: str,
    generation_config: BarkGenerationConfig,
    batch_size: int = 16,
    silent: bool = False,
    save_path: str = "./dataset",
    save_data_as_raw_audio: bool = True,
) -> None:
    """
    Generate a dataset of (wav, semantic_tokens) pairs for training a model to predict semantic tokens from audio.

    Args:
        text_file_path: path to the text file whose lines will be used to generate audio data
        generation_config: the config used to generate data, including whether the `small` or `large` BARK coarse and fine model variants are used
        batch_size: batch size used when generating data
        silent: if True, suppress the progress bar
        save_path: path to save the generated dataset
        save_data_as_raw_audio: if True, waves will be saved as raw audio; otherwise they will be saved as a compressed .npz file
    """
    texts = read_text_file(text_file_path)
    assert len(texts) > 0, "empty text data"

    mini_batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
    progress_bar = tqdm(
        total=len(mini_batches), disable=silent, desc="Generating wav-semantic dataset"
    )
    for batch in mini_batches:
        semantic_tokens = generate_semantic_tokens_from_text(
            texts=batch, semantic_prompt=None, silent=True, **asdict(generation_config)
        )

        coarse = generate_coarse_tokens_from_semantic(
            semantic_tokens=semantic_tokens,
            history_prompt=None,
            silent=True,
            **asdict(generation_config)
        )

        fine = generate_fine_tokens_from_coarse(
            coarse_tokens=coarse,
            history_prompt=None,
            temperature=generation_config.generate_fine_temperature,
            use_small_model=generation_config.use_small_model,
            silent=True,
        )

        # generate audio waves from the fine tokens
        waves = encodec_decode_fine_tokens_to_audio(fine)
        # remove the channel dimension
        waves = waves.squeeze(1)

        waves = remove_padded_segment_from_audio(waves, semantic_tokens.cpu().numpy())

        save_semantic_wave_data(
            batch,
            waves,
            semantic_tokens.detach().cpu().numpy(),
            24000,
            generation_config,
            save_path,
            save_data_as_raw_audio,
        )

        progress_bar.update(1)
        del semantic_tokens, coarse, fine, waves


def save_semantic_wave_data(
    texts: List[str],
    waves: List[np.ndarray],
    semantic_tokens: np.ndarray,
    sample_rate: int,
    generation_config: BarkGenerationConfig,
    save_path: str,
    save_raw_audio: bool,
) -> None:
    """
    Save the given data as a WavSemantic dataset.
    """
    examples = []
    assert (
        len(texts) == len(waves) == semantic_tokens.shape[0]
    ), "unexpected array length"

    model_type = "small" if generation_config.use_small_model else "large"

    # remove the padding tokens at the end of the semantic sequences
    mask = semantic_tokens == SEMANTIC_PAD_TOKEN
    semantic_padding_indices = np.argmax(mask.astype(np.int32), axis=1)

    for i, (text, padding_index) in enumerate(zip(texts, semantic_padding_indices)):
        if padding_index == 0:
            padding_index = len(semantic_tokens[i])
        example = WavSemantic(text, waves[i], semantic_tokens[i, :padding_index])
        examples.append(example)

    dataset = WavSemanticDataset(sample_rate, generation_config, model_type, examples)

    dataset.save(save_path, save_raw_audio)


def read_text_file(path: str) -> List[str]:
    with open(path, "r") as file:
        lines = file.readlines()
        # Remove newline characters
        lines = [line.strip() for line in lines]
    return lines
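Note (not part of the repository files): a hedged usage sketch for generate_wav_semantic_dataset. The exact fields of BarkGenerationConfig live in core/data_model/bark.py, which is not shown in this view, so the default constructor and the input file name are assumptions.

```python
# Hypothetical driver for the dataset generation above.
from core.data_model import BarkGenerationConfig
from core.bark.generate_audio_semantic_dataset import generate_wav_semantic_dataset

if __name__ == "__main__":
    config = BarkGenerationConfig()  # assumed to have sensible defaults
    generate_wav_semantic_dataset(
        text_file_path="sentences.txt",   # one sentence per line (hypothetical file)
        generation_config=config,
        batch_size=8,
        silent=False,
        save_path="./dataset",
        save_data_as_raw_audio=True,
    )
```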
core/bark/generate_coarse.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from tqdm import tqdm
|
5 |
+
from typing_extensions import Optional, Union, Tuple
|
6 |
+
|
7 |
+
from core.bark.constants import *
|
8 |
+
from core.model.bark import GPT
|
9 |
+
from core.data_model.bark import BarkPrompt
|
10 |
+
from core.bark.custom_context import inference_mode
|
11 |
+
|
12 |
+
from core.memory import model_manager, ModelEnum, env
|
13 |
+
|
14 |
+
# number of coarse tokens per one semantic token for one second
|
15 |
+
num_coarse_per_semantic = (COARSE_RATE_HZ / SEMANTIC_RATE_HZ) * N_COARSE_CODEBOOKS
|
16 |
+
|
17 |
+
|
18 |
+
def generate_coarse_tokens_from_semantic(
|
19 |
+
semantic_tokens: torch.Tensor,
|
20 |
+
history_prompt: Union[BarkPrompt, None] = None,
|
21 |
+
generate_coarse_temperature: Union[float, None] = 0.6,
|
22 |
+
coarse_top_k: Union[int, None] = None,
|
23 |
+
coarse_top_p: Union[float, None] = None,
|
24 |
+
silent: bool = False,
|
25 |
+
max_coarse_history: int = 630,
|
26 |
+
sliding_window_length: int = 60,
|
27 |
+
use_kv_caching: bool = True,
|
28 |
+
use_small_model: bool = False,
|
29 |
+
**kwargs,
|
30 |
+
) -> torch.Tensor:
|
31 |
+
# Validate inputs
|
32 |
+
_validate_semantic_tokens(semantic_tokens)
|
33 |
+
_validate_history_prompt(history_prompt)
|
34 |
+
|
35 |
+
assert (
|
36 |
+
60 <= max_coarse_history <= 630
|
37 |
+
), "max_coarse_history must be between 60 and 630"
|
38 |
+
assert (
|
39 |
+
max_coarse_history + sliding_window_length <= 1024 - 256
|
40 |
+
), "Context exceeds model limit"
|
41 |
+
|
42 |
+
# align the number of semantic history token with the given max_coarse_history
|
43 |
+
max_semantic_history = int(max_coarse_history / num_coarse_per_semantic)
|
44 |
+
|
45 |
+
# align the length of the provided semantic and coarse history
|
46 |
+
semantic_history, coarse_history = _process_history_prompt(
|
47 |
+
history_prompt, max_semantic_history, num_coarse_per_semantic
|
48 |
+
)
|
49 |
+
|
50 |
+
# Load coarse model
|
51 |
+
coarse_model_info = (
|
52 |
+
ModelEnum.BARK_COARSE_SMALL.value
|
53 |
+
if use_small_model
|
54 |
+
else ModelEnum.BARK_COARSE.value
|
55 |
+
)
|
56 |
+
model_wrapper = model_manager.get_model(coarse_model_info)
|
57 |
+
model: GPT = model_wrapper.model
|
58 |
+
assert isinstance(model, GPT), "unexpected model type"
|
59 |
+
|
60 |
+
# total_steps is the number of coarse tokens the model need to predict
|
61 |
+
total_steps = int(
|
62 |
+
np.floor(semantic_tokens.size(1) * num_coarse_per_semantic / N_COARSE_CODEBOOKS)
|
63 |
+
* N_COARSE_CODEBOOKS
|
64 |
+
)
|
65 |
+
assert (
|
66 |
+
total_steps > 0 and total_steps % N_COARSE_CODEBOOKS == 0
|
67 |
+
), "Invalid step count"
|
68 |
+
|
69 |
+
batch, T = semantic_tokens.size()
|
70 |
+
# expand the semantic history at the batch dimension to match with the semantic_tokens tensor's batch size
|
71 |
+
# for the concatenation
|
72 |
+
semantic_history = semantic_history[None].expand((batch, -1))
|
73 |
+
full_semantic = torch.hstack([semantic_history, semantic_tokens]).to(torch.int32)
|
74 |
+
base_semantic_index = semantic_history.size(1)
|
75 |
+
|
76 |
+
# Generate coarse tokens
|
77 |
+
with inference_mode():
|
78 |
+
generated_coarse = _generate_coarse_with_sliding_window(
|
79 |
+
model,
|
80 |
+
full_semantic,
|
81 |
+
coarse_history,
|
82 |
+
total_steps,
|
83 |
+
base_semantic_index,
|
84 |
+
max_semantic_history,
|
85 |
+
num_coarse_per_semantic,
|
86 |
+
generate_coarse_temperature,
|
87 |
+
coarse_top_k,
|
88 |
+
coarse_top_p,
|
89 |
+
silent,
|
90 |
+
max_coarse_history,
|
91 |
+
sliding_window_length,
|
92 |
+
use_kv_caching,
|
93 |
+
)
|
94 |
+
|
95 |
+
# remove the history prompt from the generated tokens
|
96 |
+
generated_coarse = generated_coarse[:, coarse_history.size(0) :]
|
97 |
+
assert generated_coarse.size(1) == total_steps, "Generated length mismatch"
|
98 |
+
|
99 |
+
# Reshape and adjust coarse codes
|
100 |
+
B, L = generated_coarse.shape
|
101 |
+
# Broadcasting subtracts from all elements
|
102 |
+
coarse_output = (
|
103 |
+
generated_coarse.reshape(B, -1, N_COARSE_CODEBOOKS).transpose(1, 2)
|
104 |
+
- SEMANTIC_VOCAB_SIZE
|
105 |
+
)
|
106 |
+
|
107 |
+
for codebook_idx in range(1, N_COARSE_CODEBOOKS):
|
108 |
+
coarse_output[:, codebook_idx, :] -= codebook_idx * CODEBOOK_SIZE
|
109 |
+
|
110 |
+
return coarse_output
|
111 |
+
|
112 |
+
|
113 |
+
def _validate_semantic_tokens(semantic_tokens: torch.Tensor) -> None:
|
114 |
+
"""
|
115 |
+
Validate the input semantic tokens tensor.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
semantic_tokens: Tensor of semantic tokens (1D).
|
119 |
+
|
120 |
+
Raises:
|
121 |
+
AssertionError: If the tensor does not meet expected criteria.
|
122 |
+
"""
|
123 |
+
assert isinstance(
|
124 |
+
semantic_tokens, torch.Tensor
|
125 |
+
), "Semantic tokens must be a torch.Tensor"
|
126 |
+
assert semantic_tokens.dim() == 2, "Semantic tokens must be 2D"
|
127 |
+
assert semantic_tokens.size(1) > 0, "Semantic tokens tensor cannot be empty"
|
128 |
+
assert semantic_tokens.min() >= 0, "Semantic tokens must be non-negative"
|
129 |
+
assert (
|
130 |
+
semantic_tokens.max() <= SEMANTIC_VOCAB_SIZE
|
131 |
+
), "Semantic tokens exceed vocab size"
|
132 |
+
|
133 |
+
|
134 |
+
def _validate_history_prompt(history_prompt: Union[BarkPrompt, None]) -> None:
|
135 |
+
"""
|
136 |
+
Validate the history prompt if provided.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
history_prompt: BarkPrompt object or None.
|
140 |
+
|
141 |
+
Raises:
|
142 |
+
AssertionError: If the prompt does not meet expected criteria.
|
143 |
+
"""
|
144 |
+
if history_prompt is None:
|
145 |
+
return
|
146 |
+
|
147 |
+
assert isinstance(
|
148 |
+
history_prompt, BarkPrompt
|
149 |
+
), "History prompt must be a BarkPrompt object"
|
150 |
+
semantic = history_prompt.semantic_prompt
|
151 |
+
coarse = history_prompt.coarse_prompt
|
152 |
+
|
153 |
+
assert (
|
154 |
+
isinstance(semantic, torch.Tensor) and semantic.dim() == 1
|
155 |
+
), "Semantic prompt must be 1D tensor"
|
156 |
+
assert (
|
157 |
+
semantic.size(0) > 0
|
158 |
+
and semantic.min() >= 0
|
159 |
+
and semantic.max() <= SEMANTIC_VOCAB_SIZE - 1
|
160 |
+
)
|
161 |
+
assert (
|
162 |
+
isinstance(coarse, torch.Tensor) and coarse.dim() == 2
|
163 |
+
), "Coarse prompt must be 2D tensor"
|
164 |
+
assert (
|
165 |
+
coarse.shape[0] == N_COARSE_CODEBOOKS
|
166 |
+
), "Coarse prompt must have correct number of codebooks"
|
167 |
+
assert coarse.min() >= 0 and coarse.max() <= CODEBOOK_SIZE - 1
|
168 |
+
|
169 |
+
|
170 |
+
def _process_history_prompt(
|
171 |
+
history_prompt: Union[BarkPrompt, None],
|
172 |
+
max_semantic_history: int,
|
173 |
+
coarse_to_semantic_ratio: float,
|
174 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
175 |
+
"""
|
176 |
+
Process the history prompt into semantic and coarse history tensors.
|
177 |
+
Trim on the left (keep the right most tokens)
|
178 |
+
Args:
|
179 |
+
history_prompt: BarkPrompt object or None.
|
180 |
+
max_semantic_history: Maximum number of semantic history tokens.
|
181 |
+
coarse_to_semantic_ratio: Ratio of coarse to semantic token rates.
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
Tuple[semantic_history, coarse_history]: Processed history tensors.
|
185 |
+
"""
|
186 |
+
if history_prompt is None:
|
187 |
+
return torch.tensor(
|
188 |
+
[], dtype=torch.int32, device=torch.device(env.DEVICE)
|
189 |
+
), torch.tensor([], dtype=torch.int32, device=torch.device(env.DEVICE))
|
190 |
+
|
191 |
+
semantic_history = history_prompt.semantic_prompt
|
192 |
+
coarse_history = history_prompt.coarse_prompt
|
193 |
+
|
194 |
+
# Add offset then "ravel("F")" flatten
|
195 |
+
coarse_history = _add_codebook_offset(coarse_history, CODEBOOK_SIZE)
|
196 |
+
coarse_history_flat = coarse_history.T.flatten() + SEMANTIC_VOCAB_SIZE
|
197 |
+
|
198 |
+
# Trim histories to fit max length
|
199 |
+
n_semantic_hist = min(
|
200 |
+
max_semantic_history,
|
201 |
+
semantic_history.size(0) - semantic_history.size(0) % 2, # Ensure even length
|
202 |
+
int(coarse_history_flat.size(0) // coarse_to_semantic_ratio),
|
203 |
+
)
|
204 |
+
n_coarse_hist = int(round(n_semantic_hist * coarse_to_semantic_ratio))
|
205 |
+
|
206 |
+
semantic_history = semantic_history[-n_semantic_hist:].to(torch.int32)
|
207 |
+
coarse_history_flat = coarse_history_flat[-n_coarse_hist:].to(torch.int32)
|
208 |
+
coarse_history_flat = coarse_history_flat[:-2] # Original time alignment hack
|
209 |
+
|
210 |
+
return semantic_history, coarse_history_flat
|
211 |
+
|
212 |
+
|
213 |
+
def _add_codebook_offset(x: torch.Tensor, offset: int) -> torch.Tensor:
|
214 |
+
"""
|
215 |
+
x shape (n_codebook, T)
|
216 |
+
n_codebook start from 0 to n, from the second codebook row on we add offset * row_num
|
217 |
+
"""
|
218 |
+
for n in range(1, x.shape[0]):
|
219 |
+
x[n, :] += offset * n
|
220 |
+
return x
|
221 |
+
|
222 |
+
|
223 |
+
def _sample_coarse_token(
|
224 |
+
logits: torch.Tensor,
|
225 |
+
temperature: Union[float, None],
|
226 |
+
top_k: Optional[int],
|
227 |
+
top_p: Optional[float],
|
228 |
+
logit_start_idx: int,
|
229 |
+
) -> torch.Tensor:
|
230 |
+
"""
|
231 |
+
Sample a coarse token from model logits with filtering.
|
232 |
+
|
233 |
+
Args:
|
234 |
+
logits: Model output logits (shape [batch, seq, vocab]).
|
235 |
+
temperature: Sampling temperature for randomness.
|
236 |
+
top_k: Number of top logits to consider, if specified.
|
237 |
+
top_p: Nucleus sampling threshold, if specified.
|
238 |
+
logit_start_idx: Starting index for coarse token logits.
|
239 |
+
|
240 |
+
Returns:
|
241 |
+
torch.Tensor: Sampled token with offset applied (shape [1]).
|
242 |
+
"""
|
243 |
+
relevant_logits = logits[:, 0, logit_start_idx : logit_start_idx + CODEBOOK_SIZE]
|
244 |
+
|
245 |
+
if temperature is None:
|
246 |
+
probs = F.softmax(relevant_logits, dim=-1)
|
247 |
+
next_token = torch.argmax(probs, dim=-1, keepdim=True).to(torch.int32)
|
248 |
+
else:
|
249 |
+
if top_p is not None: # this branch is untested
|
250 |
+
# Optimize with NumPy for top-p filtering,
|
251 |
+
original_device = relevant_logits.device
|
252 |
+
logits_np = relevant_logits.detach().cpu().numpy().astype(np.float32)
|
253 |
+
sorted_indices = np.argsort(logits_np)[::-1]
|
254 |
+
sorted_logits = logits_np[sorted_indices]
|
255 |
+
cumulative_probs = np.cumsum(
|
256 |
+
F.softmax(torch.from_numpy(sorted_logits), dim=-1).numpy()
|
257 |
+
)
|
258 |
+
indices_to_remove = cumulative_probs > top_p
|
259 |
+
indices_to_remove[1:] = indices_to_remove[:-1].copy()
|
260 |
+
indices_to_remove[0] = False
|
261 |
+
logits_np[sorted_indices[indices_to_remove]] = -np.inf
|
262 |
+
relevant_logits = torch.from_numpy(logits_np).to(original_device)
|
263 |
+
|
264 |
+
if top_k is not None:
|
265 |
+
top_values, _ = torch.topk(
|
266 |
+
relevant_logits, min(top_k, relevant_logits.size(-1))
|
267 |
+
)
|
268 |
+
relevant_logits[relevant_logits < top_values[:, [-1]]] = -float("Inf")
|
269 |
+
|
270 |
+
probs = F.softmax(relevant_logits / temperature, dim=-1)
|
271 |
+
next_token = torch.multinomial(probs, num_samples=1).to(torch.int32)
|
272 |
+
return next_token + logit_start_idx
|
273 |
+
|
274 |
+
|
275 |
+
def _generate_coarse_with_sliding_window(
|
276 |
+
model: GPT,
|
277 |
+
full_semantic: torch.Tensor,
|
278 |
+
coarse_history: torch.Tensor,
|
279 |
+
total_steps: int,
|
280 |
+
base_semantic_index: int,
|
281 |
+
max_semantic_history: int,
|
282 |
+
coarse_per_semantic: float,
|
283 |
+
temperature: float,
|
284 |
+
top_k: Optional[int],
|
285 |
+
top_p: Optional[float],
|
286 |
+
silent: bool,
|
287 |
+
max_coarse_history: int,
|
288 |
+
sliding_window_length: int,
|
289 |
+
use_kv_caching: bool,
|
290 |
+
) -> torch.Tensor:
|
291 |
+
"""
|
292 |
+
Generate coarse tokens using a sliding window approach.
|
293 |
+
|
294 |
+
Args:
|
295 |
+
model: GPT model for coarse token generation.
|
296 |
+
full_semantic: 2D tensor of Concatenated semantic history and input tokens.
|
297 |
+
coarse_history: 1D tensor, Initial coarse history tokens.
|
298 |
+
total_steps: Total number of coarse tokens to generate.
|
299 |
+
base_semantic_index: Start index of input semantic tokens.
|
300 |
+
max_semantic_history: Maximum semantic history length.
|
301 |
+
coarse_per_semantic: Coarse-to-semantic token ratio.
|
302 |
+
temperature: Sampling temperature.
|
303 |
+
top_k: Top-k filtering parameter.
|
304 |
+
top_p: Top-p filtering parameter.
|
305 |
+
silent: Suppresses progress bar if True.
|
306 |
+
max_coarse_history: Maximum coarse history length.
|
307 |
+
sliding_window_length: Tokens per window.
|
308 |
+
use_kv_caching: Enables KV caching.
|
309 |
+
|
310 |
+
Returns:
|
311 |
+
torch.Tensor: Generated coarse tokens (1D).
|
312 |
+
"""
|
313 |
+
device = next(model.parameters()).device
|
314 |
+
semantic_tensor = full_semantic.to(device) # Add batch dimension
|
315 |
+
coarse_tensor = (
|
316 |
+
coarse_history[None].expand((semantic_tensor.shape[0], -1)).to(device)
|
317 |
+
)
|
318 |
+
|
319 |
+
window_count = int(np.ceil(total_steps / sliding_window_length))
|
320 |
+
progress_bar = tqdm(
|
321 |
+
total=window_count, disable=silent, desc="Generating coarse tokens"
|
322 |
+
)
|
323 |
+
step_counter = 0 # equivalent to the number of coarse tokens generated so far
|
324 |
+
|
325 |
+
for _ in range(window_count):
|
326 |
+
current_semantic_idx = base_semantic_index + int(
|
327 |
+
round(step_counter / coarse_per_semantic)
|
328 |
+
)
|
329 |
+
|
330 |
+
window_start = max(0, current_semantic_idx - max_semantic_history)
|
331 |
+
semantic_window = semantic_tensor[:, window_start : window_start + 256]
|
332 |
+
semantic_window = F.pad(
|
333 |
+
semantic_window,
|
334 |
+
(0, 256 - semantic_window.shape[-1]),
|
335 |
+
"constant",
|
336 |
+
COARSE_SEMANTIC_PAD_TOKEN,
|
337 |
+
)
|
338 |
+
|
339 |
+
input_tensor = torch.hstack(
|
340 |
+
[
|
341 |
+
semantic_window,
|
342 |
+
torch.tensor([COARSE_INFER_TOKEN], device=device)[None].expand(
|
343 |
+
(semantic_window.shape[0], -1)
|
344 |
+
),
|
345 |
+
coarse_tensor[:, -max_coarse_history:],
|
346 |
+
]
|
347 |
+
)
|
348 |
+
|
349 |
+
kv_cache = None
|
350 |
+
for _ in range(sliding_window_length):
|
351 |
+
if step_counter >= total_steps:
|
352 |
+
break
|
353 |
+
|
354 |
+
is_first_codebook = step_counter % N_COARSE_CODEBOOKS == 0
|
355 |
+
logit_start_idx = (
|
356 |
+
SEMANTIC_VOCAB_SIZE + (1 - int(is_first_codebook)) * CODEBOOK_SIZE
|
357 |
+
)
|
358 |
+
|
359 |
+
model_input = (
|
360 |
+
input_tensor[:, [-1]]
|
361 |
+
if use_kv_caching and kv_cache is not None
|
362 |
+
else input_tensor
|
363 |
+
)
|
364 |
+
logits, kv_cache = model(
|
365 |
+
model_input, use_cache=use_kv_caching, past_kv=kv_cache
|
366 |
+
)
|
367 |
+
next_token = _sample_coarse_token(
|
368 |
+
logits,
|
369 |
+
temperature,
|
370 |
+
top_k,
|
371 |
+
top_p,
|
372 |
+
logit_start_idx,
|
373 |
+
)
|
374 |
+
|
375 |
+
coarse_tensor = torch.cat((coarse_tensor, next_token), dim=1)
|
376 |
+
input_tensor = torch.cat((input_tensor, next_token), dim=1)
|
377 |
+
|
378 |
+
step_counter += 1
|
379 |
+
del logits, next_token
|
380 |
+
|
381 |
+
del input_tensor
|
382 |
+
progress_bar.update(1)
|
383 |
+
|
384 |
+
progress_bar.close()
|
385 |
+
return coarse_tensor
|
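To make the window bookkeeping above concrete, here is a hedged sketch (with made-up numbers rather than the repo's real coarse-to-semantic frame-rate ratio) of how each sliding window re-derives its semantic anchor from the number of coarse tokens generated so far:

import numpy as np

coarse_per_semantic = 3.0        # illustrative ratio, not the real constant
sliding_window_length = 60
max_semantic_history = 256
total_steps = 300

step_counter = 0
for _ in range(int(np.ceil(total_steps / sliding_window_length))):
    # each window re-anchors on the semantic position implied by the coarse tokens generated so far
    semantic_idx = int(round(step_counter / coarse_per_semantic))
    window_start = max(0, semantic_idx - max_semantic_history)
    print(f"coarse step {step_counter:3d} -> semantic window [{window_start}, {window_start + 256})")
    step_counter = min(step_counter + sliding_window_length, total_steps)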
core/bark/generate_fine.py
ADDED
@@ -0,0 +1,210 @@
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from tqdm import tqdm
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from core.data_model.bark import BarkPrompt
|
8 |
+
from core.bark.custom_context import inference_mode
|
9 |
+
from core.model import FineGPT
|
10 |
+
from core.memory import ModelEnum, model_manager
|
11 |
+
from core.bark.constants import *
|
12 |
+
|
13 |
+
|
14 |
+
def generate_fine_tokens_from_coarse(
|
15 |
+
coarse_tokens: torch.Tensor,
|
16 |
+
history_prompt: Union[BarkPrompt, None] = None,
|
17 |
+
temperature: float = 0.5,
|
18 |
+
use_small_model: bool = True,
|
19 |
+
silent: bool = False,
|
20 |
+
) -> torch.Tensor:
|
21 |
+
"""
|
22 |
+
Generate fine-grained audio codes from coarse audio codes using the BARK fine model.
|
23 |
+
|
24 |
+
This function takes coarse tokens (representing a partial set of audio codebooks) and
|
25 |
+
autoregressively predicts the remaining fine tokens, optionally conditioning on a history
|
26 |
+
prompt. The process involves sliding a context window over the sequence, predicting 512
|
27 |
+
timesteps at a time based on a 1024-timestep input.
|
28 |
+
|
29 |
+
Prompt tokens are trimmed on the left (keeping the rightmost tokens).
|
30 |
+
|
31 |
+
Args:
|
32 |
+
coarse_tokens (torch.Tensor): Coarse audio codes with shape (batch, n_coarse, sequence_length),
|
33 |
+
where n_coarse <= N_FINE_CODEBOOKS - 1 and values are in [0, CODEBOOK_SIZE - 1].
|
34 |
+
history_prompt (BarkPrompt, optional): Historical fine tokens for conditioning, or None.
|
35 |
+
temperature (float): Sampling temperature for fine token prediction; if None, uses argmax.
|
36 |
+
silent (bool): If True, suppresses progress bar output.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
torch.Tensor: Fine audio codes with shape (N_FINE_CODEBOOKS, sequence_length),
|
40 |
+
matching the input sequence_length.
|
41 |
+
|
42 |
+
Raises:
|
43 |
+
AssertionError: If input validation fails for coarse_tokens or history_prompt.
|
44 |
+
"""
|
45 |
+
# Validate inputs
|
46 |
+
_validate_coarse_tokens(coarse_tokens=coarse_tokens)
|
47 |
+
history_fine_tokens = _validate_and_load_history(history_prompt=history_prompt)
|
48 |
+
batch, n_coarse, sequence_length = coarse_tokens.shape
|
49 |
+
|
50 |
+
# Load the fine model
|
51 |
+
model_info = (
|
52 |
+
ModelEnum.BARK_FINE_SMALL.value
|
53 |
+
if use_small_model
|
54 |
+
else ModelEnum.BARK_FINE.value
|
55 |
+
)
|
56 |
+
model_wrapper = model_manager.get_model(model_info)
|
57 |
+
model: FineGPT = model_wrapper.model
|
58 |
+
assert isinstance(model, FineGPT), "Expected FineGPT model type"
|
59 |
+
device = next(model.parameters()).device
|
60 |
+
coarse_tokens = coarse_tokens.to(device)
|
61 |
+
# stack coarse tokens with padding for remaining codebooks across the codebook dimension
|
62 |
+
# e.g. original coarse_tokens shape (B, 2, T); after concatenation: (B, 8, T), where N_FINE_CODEBOOKS = 8
|
63 |
+
pad_tensor = torch.full(
|
64 |
+
(batch, N_FINE_CODEBOOKS - n_coarse, sequence_length),
|
65 |
+
CODEBOOK_SIZE,
|
66 |
+
dtype=torch.int32,
|
67 |
+
device=device,
|
68 |
+
)
|
69 |
+
|
70 |
+
input_tensor = torch.cat((coarse_tokens, pad_tensor), dim=1)
|
71 |
+
|
72 |
+
# Prepend history if provided. Maximum history time step is 512
|
73 |
+
# this is a horizontal prepend on the left of the previous padded input tensor
|
74 |
+
# output tensor: (8, history_timestep + coarse_timestep), history_timestep <= 512
|
75 |
+
n_history = 0
|
76 |
+
if history_fine_tokens is not None:
|
77 |
+
history_fine_tokens = history_fine_tokens.expand((batch, N_FINE_CODEBOOKS, -1))
|
78 |
+
history_limit = min(history_fine_tokens.shape[-1], 512)
|
79 |
+
history_slice = history_fine_tokens[:, :, -history_limit:].to(
|
80 |
+
device, dtype=torch.int32
|
81 |
+
)
|
82 |
+
input_tensor = torch.cat((history_slice, input_tensor), dim=-1)
|
83 |
+
n_history = history_limit  # number of history timesteps prepended from the prompt
|
84 |
+
|
85 |
+
# right-pad if total_length (history timesteps + coarse timesteps) is less than the model context (1024)
|
86 |
+
total_length = input_tensor.shape[-1]
|
87 |
+
padding_needed = max(0, 1024 - total_length)
|
88 |
+
if padding_needed > 0:
|
89 |
+
padding = torch.full(
|
90 |
+
(batch, N_FINE_CODEBOOKS, padding_needed),
|
91 |
+
CODEBOOK_SIZE,
|
92 |
+
dtype=torch.int32,
|
93 |
+
device=device,
|
94 |
+
)
|
95 |
+
input_tensor = torch.cat((input_tensor, padding), dim=2)
|
96 |
+
total_length = input_tensor.shape[-1]
|
97 |
+
|
98 |
+
# Calculate number of prediction loops
|
99 |
+
context_window = 1024 # Model's input context size
|
100 |
+
prediction_step = 512 # Number of new timesteps predicted per loop
|
101 |
+
remaining_length = max(0, sequence_length - (context_window - n_history))
|
102 |
+
extra_loops = (remaining_length + prediction_step - 1) // prediction_step
|
103 |
+
n_loops = 1 + extra_loops # Total loops: initial + extra
|
104 |
+
|
105 |
+
# Process sequence in sliding windows
|
106 |
+
input_tensor = input_tensor.transpose(
|
107 |
+
-2, -1
|
108 |
+
)  # Shape: (batch, total_length, N_FINE_CODEBOOKS)
|
109 |
+
with inference_mode():
|
110 |
+
for loop_idx in tqdm(
|
111 |
+
range(n_loops), disable=silent, desc="Generating fine tokens"
|
112 |
+
):
|
113 |
+
# Define window boundaries
|
114 |
+
# the last loop, by using window_start = (total_length - context_window),
|
115 |
+
# the input will be: input_tensor[:, -1024:, :], the last context_window timestep of the input
|
116 |
+
window_start = min(
|
117 |
+
loop_idx * prediction_step, total_length - context_window
|
118 |
+
)
|
119 |
+
|
120 |
+
fill_start = min(
|
121 |
+
n_history + loop_idx * prediction_step, total_length - prediction_step
|
122 |
+
)
|
123 |
+
fill_offset = fill_start - window_start
|
124 |
+
window_end = window_start + context_window
|
125 |
+
|
126 |
+
# Extract input window
|
127 |
+
# Shape: (batch, 1024, N_FINE_CODEBOOKS)
|
128 |
+
input_window = input_tensor[:, window_start:window_end, :]
|
129 |
+
|
130 |
+
# Predict fine codebooks autoregressively
|
131 |
+
for codebook_idx in range(n_coarse, N_FINE_CODEBOOKS):
|
132 |
+
# Shape: (batch, 1024, vocab_size)
|
133 |
+
logits = model(codebook_idx, input_window)
|
134 |
+
if temperature is None:
|
135 |
+
preds = torch.argmax(
|
136 |
+
logits[:, fill_offset:, :CODEBOOK_SIZE], dim=-1
|
137 |
+
)
|
138 |
+
else:
|
139 |
+
scaled_logits = logits[:, :, :CODEBOOK_SIZE] / temperature
|
140 |
+
probs = F.softmax(scaled_logits, dim=-1)
|
141 |
+
probs = probs[:, fill_offset:, :]
|
142 |
+
# Reshape to [B * N, C] for multinomial
|
143 |
+
B, N, C = probs.shape  # B=batch, N=window steps after fill_offset, C=CODEBOOK_SIZE
|
144 |
+
probs_2d = probs.reshape(-1, C) # Shape: [2 * N, 1024]
|
145 |
+
|
146 |
+
# Perform multinomial sampling
|
147 |
+
# Shape: [2 * N, 1]
|
148 |
+
preds = torch.multinomial(probs_2d, num_samples=1)
|
149 |
+
|
150 |
+
# Reshape back to [2, N] after squeezing
|
151 |
+
preds = preds.squeeze(-1).reshape(B, N)
|
152 |
+
|
153 |
+
input_window[:, fill_offset:, codebook_idx] = preds.to(torch.int32)
|
154 |
+
|
155 |
+
# Update main tensor with predictions
|
156 |
+
fill_length = min(prediction_step, total_length - fill_start)
|
157 |
+
input_tensor[:, fill_start : fill_start + fill_length, codebook_idx] = (
|
158 |
+
input_window[
|
159 |
+
:, fill_offset : fill_offset + fill_length, codebook_idx
|
160 |
+
]
|
161 |
+
)
|
162 |
+
|
163 |
+
# Extract final result, removing history and padding
|
164 |
+
# Shape: (N_FINE_CODEBOOKS, sequence_length)
|
165 |
+
fine_tokens = input_tensor.transpose(-1, -2)[
|
166 |
+
:, :, n_history : n_history + sequence_length
|
167 |
+
]
|
168 |
+
|
169 |
+
# Verify output shape matches input sequence length
|
170 |
+
assert fine_tokens.shape[-1] == sequence_length, "Output length mismatch"
|
171 |
+
|
172 |
+
return fine_tokens
|
173 |
+
|
174 |
+
|
175 |
+
def _validate_coarse_tokens(coarse_tokens: torch.Tensor) -> None:
|
176 |
+
"""Validate coarse token tensor properties."""
|
177 |
+
assert isinstance(
|
178 |
+
coarse_tokens, torch.Tensor
|
179 |
+
), "coarse_tokens must be a torch.Tensor"
|
180 |
+
assert len(coarse_tokens.shape) == 3, "coarse_tokens must be 3D"
|
181 |
+
assert (
|
182 |
+
1 <= coarse_tokens.shape[1] <= N_FINE_CODEBOOKS - 1
|
183 |
+
), "Invalid number of coarse codebooks"
|
184 |
+
assert coarse_tokens.shape[-1] > 0, "Sequence length must be positive"
|
185 |
+
assert (
|
186 |
+
coarse_tokens.min() >= 0 and coarse_tokens.max() <= CODEBOOK_SIZE
|
187 |
+
), "Token values out of range"
|
188 |
+
|
189 |
+
|
190 |
+
def _validate_and_load_history(
|
191 |
+
history_prompt: Union[BarkPrompt, None],
|
192 |
+
) -> Union[torch.Tensor, None]:
|
193 |
+
"""Validate and load history prompt if provided."""
|
194 |
+
if history_prompt is None:
|
195 |
+
return None
|
196 |
+
|
197 |
+
history_fine_tokens = history_prompt.fine_prompt
|
198 |
+
assert isinstance(
|
199 |
+
history_fine_tokens, torch.Tensor
|
200 |
+
), "history_prompt.fine_prompt must be a torch.Tensor"
|
201 |
+
assert len(history_fine_tokens.shape) == 2, "History must be 2D"
|
202 |
+
assert (
|
203 |
+
history_fine_tokens.shape[0] == N_FINE_CODEBOOKS
|
204 |
+
), "History must have all fine codebooks"
|
205 |
+
assert history_fine_tokens.shape[1] > 0, "History must not be empty"
|
206 |
+
assert (
|
207 |
+
history_fine_tokens.min() >= 0
|
208 |
+
and history_fine_tokens.max() <= CODEBOOK_SIZE - 1
|
209 |
+
), "History values out of range"
|
210 |
+
return history_fine_tokens
|
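A sketch of the fine-model window bookkeeping used above, assuming a 1024-step context and 512 newly filled steps per pass; n_history and sequence_length are made-up example values, not outputs of the repo:

context_window, prediction_step = 1024, 512
n_history, sequence_length = 256, 1800
total_length = n_history + sequence_length          # already >= 1024 here, so no right-padding

remaining = max(0, sequence_length - (context_window - n_history))
n_loops = 1 + (remaining + prediction_step - 1) // prediction_step

for loop_idx in range(n_loops):
    window_start = min(loop_idx * prediction_step, total_length - context_window)
    fill_start = min(n_history + loop_idx * prediction_step, total_length - prediction_step)
    fill_offset = fill_start - window_start
    print(loop_idx, window_start, fill_start, fill_offset)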
core/bark/generate_semantic.py
ADDED
@@ -0,0 +1,361 @@
1 |
+
from typing import List, Tuple, Optional, Union
|
2 |
+
import re
|
3 |
+
from tqdm import tqdm
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from transformers import BertTokenizer
|
9 |
+
|
10 |
+
from core.memory import model_manager, ModelEnum, env
|
11 |
+
from core.bark.custom_context import inference_mode
|
12 |
+
from core.bark.constants import *
|
13 |
+
from core.model import GPT
|
14 |
+
|
15 |
+
SEMANTIC_EOS_TOKEN = 10_000
|
16 |
+
|
17 |
+
|
18 |
+
def generate_semantic_tokens_from_text(
|
19 |
+
texts: List[str],
|
20 |
+
semantic_prompt: Union[torch.Tensor, None] = None,
|
21 |
+
temperature: Union[float, None] = 0.7,
|
22 |
+
semantic_top_k: Union[int, None] = None,
|
23 |
+
semantic_top_p: Union[int, None] = None,
|
24 |
+
min_eos_p: float = 0.2,
|
25 |
+
max_gen_duration_second: Union[float, None] = None,
|
26 |
+
allow_early_stop: bool = True,
|
27 |
+
use_kv_caching: bool = True,
|
28 |
+
use_small_model: bool = True,
|
29 |
+
silent: Union[bool, None] = False,
|
30 |
+
max_token_ids_per_sentence: int = 256,
|
31 |
+
**kwargs,
|
32 |
+
) -> torch.Tensor:
|
33 |
+
# trim whitespace and collapse repeated whitespace characters
|
34 |
+
texts = _preprocess_texts(texts)
|
35 |
+
assert all([len(text) > 0 for text in texts]), f"invalid input text {texts}"
|
36 |
+
|
37 |
+
if semantic_prompt is None:
|
38 |
+
semantic_prompt = torch.tensor([])
|
39 |
+
else:
|
40 |
+
assert isinstance(
|
41 |
+
semantic_prompt, torch.Tensor
|
42 |
+
), f"expecting semantic_prompt of type torch.Tensor, received {type(semantic_prompt)}"
|
43 |
+
assert semantic_prompt.dim() == 1, "expect 1D tensor as semantic_prompt"
|
44 |
+
|
45 |
+
# load the GPT-style model that generate semantic token from text
|
46 |
+
# and the BERT tokenizer to memory
|
47 |
+
text_model_info = (
|
48 |
+
ModelEnum.BARK_TEXT_SMALL.value
|
49 |
+
if use_small_model
|
50 |
+
else ModelEnum.BARK_TEXT.value
|
51 |
+
)
|
52 |
+
|
53 |
+
text_model = model_manager.get_model(text_model_info)
|
54 |
+
assert text_model.model is not None, "text model is None"
|
55 |
+
assert text_model.preprocessor is not None, "tokenizer for the text model is None"
|
56 |
+
|
57 |
+
assert isinstance(
|
58 |
+
text_model.model, GPT
|
59 |
+
), f"expecting model of type GPT, got {type(text_model.model)}"
|
60 |
+
|
61 |
+
assert isinstance(
|
62 |
+
text_model.preprocessor, BertTokenizer
|
63 |
+
), f"expecting preprocessor of type BertTokenizer, got {type(text_model.preprocessor)}"
|
64 |
+
|
65 |
+
model: GPT = text_model.model
|
66 |
+
tokenizer: BertTokenizer = text_model.preprocessor
|
67 |
+
device = next(model.parameters()).device
|
68 |
+
|
69 |
+
# tokenize the given text using the BERT tokenizer
|
70 |
+
token_ids = [tokenizer.encode(text, add_special_tokens=False) for text in texts]
|
71 |
+
|
72 |
+
# for each token_ids of each sentence, append an encoding offset token
|
73 |
+
token_ids = [np.array(sentence) + TEXT_ENCODING_OFFSET for sentence in token_ids]
|
74 |
+
|
75 |
+
# encoded text must have length 256, following the original implementation
|
76 |
+
# pad on the right if the sentence's token_ids are shorter, trim on the left (keeping the rightmost tokens) if longer than 256 tokens
|
77 |
+
token_ids = [
|
78 |
+
trim_or_pad_array(sentence, TEXT_PAD_TOKEN, max_token_ids_per_sentence)
|
79 |
+
for sentence in token_ids
|
80 |
+
]
|
81 |
+
|
82 |
+
token_ids_tensor = torch.vstack(token_ids).to(dtype=torch.int32, device=device)
|
83 |
+
|
84 |
+
# when the token_ids list has one element (batch size = 1), the stacking above may produce a 1D tensor
|
85 |
+
# we need to check and make it 2D
|
86 |
+
if len(token_ids_tensor.shape) == 1:
|
87 |
+
token_ids_tensor = token_ids_tensor.unsqueeze(0)
|
88 |
+
# the semantic prompt also needs to be an array of 256 discrete tokens
|
89 |
+
semantic_prompt = trim_or_pad_array(semantic_prompt, SEMANTIC_PAD_TOKEN, 256)
|
90 |
+
|
91 |
+
# need to replicate the semantic_prompt array to match the shape of the token_ids for concatenation
|
92 |
+
semantic_prompt = (
|
93 |
+
semantic_prompt.unsqueeze(0).expand((token_ids_tensor.shape[0], -1)).to(device)
|
94 |
+
)
|
95 |
+
|
96 |
+
# final input is the concatenation of the token_ids and the semantic tokens array
|
97 |
+
input_tensor = torch.cat(
|
98 |
+
[
|
99 |
+
token_ids_tensor, # shape (batch_size, T)
|
100 |
+
semantic_prompt,
|
101 |
+
torch.tensor([SEMANTIC_INFER_TOKEN], device=device)
|
102 |
+
.unsqueeze(0)
|
103 |
+
.expand((token_ids_tensor.shape[0], -1)),
|
104 |
+
],
|
105 |
+
dim=1,
|
106 |
+
).to(torch.int64)
|
107 |
+
|
108 |
+
# 256 token_ids, 256 prompt tokens, 1 semantic_infer token as the last column
|
109 |
+
assert (
|
110 |
+
input_tensor.shape[1] == 256 + 256 + 1
|
111 |
+
), f"expecting tensor shape [batch, 513], received {input_tensor.shape}"
|
112 |
+
|
113 |
+
with inference_mode():
|
114 |
+
output: torch.Tensor = _generate_semantic(
|
115 |
+
model=model,
|
116 |
+
x=input_tensor,
|
117 |
+
temperature=temperature,
|
118 |
+
top_k=semantic_top_k,
|
119 |
+
top_p=semantic_top_p,
|
120 |
+
min_eos_p=min_eos_p,
|
121 |
+
max_gen_duration_s=max_gen_duration_second,
|
122 |
+
allow_early_stop=allow_early_stop,
|
123 |
+
use_kv_caching=use_kv_caching,
|
124 |
+
silent=silent,
|
125 |
+
)
|
126 |
+
|
127 |
+
validate_semantic_token_output(output)
|
128 |
+
return output
|
129 |
+
|
130 |
+
|
131 |
+
def _generate_semantic(
|
132 |
+
model: GPT,
|
133 |
+
x: torch.Tensor,
|
134 |
+
temperature: float = 0.7,
|
135 |
+
top_k: Optional[int] = None,
|
136 |
+
top_p: Optional[float] = None,
|
137 |
+
min_eos_p: float = 0.2,
|
138 |
+
max_gen_duration_s: Optional[float] = None,
|
139 |
+
allow_early_stop: bool = True,
|
140 |
+
use_kv_caching: bool = False,
|
141 |
+
silent: bool = False,
|
142 |
+
) -> torch.Tensor:
|
143 |
+
# Maximum number of tokens to generate
|
144 |
+
max_steps = 2048
|
145 |
+
|
146 |
+
# Initialize progress bar for user feedback (custom due to unpredictable stopping)
|
147 |
+
progress_bar = tqdm(
|
148 |
+
total=max_steps, disable=silent, desc="Generating semantic tokens"
|
149 |
+
)
|
150 |
+
last_progress = 0
|
151 |
+
|
152 |
+
# Key-value cache for attention optimization
|
153 |
+
kv_cache = None
|
154 |
+
|
155 |
+
# Autoregressive generation loop
|
156 |
+
for step in range(max_steps):
|
157 |
+
# Determine input based on KV caching
|
158 |
+
if use_kv_caching and kv_cache is not None:
|
159 |
+
# Use only the last token with cached attention states
|
160 |
+
x_input = x[:, [-1]] # Shape [1, 1]
|
161 |
+
else:
|
162 |
+
# Use full sequence (recomputes attention each time)
|
163 |
+
x_input = x # Shape [1, seq_len]
|
164 |
+
|
165 |
+
# Forward pass through the model
|
166 |
+
logits, kv_cache = model(
|
167 |
+
x_input,
|
168 |
+
merge_context=True, # Merges text and semantic history context
|
169 |
+
past_kv=kv_cache, # Previous attention states
|
170 |
+
use_cache=use_kv_caching, # Enables caching if requested
|
171 |
+
)
|
172 |
+
|
173 |
+
# Sample the next token and check for early stopping
|
174 |
+
next_token, should_stop = _sample_next_token(
|
175 |
+
logits=logits,
|
176 |
+
temperature=temperature,
|
177 |
+
top_k=top_k,
|
178 |
+
top_p=top_p,
|
179 |
+
semantic_eos_token=SEMANTIC_EOS_TOKEN,
|
180 |
+
allow_early_stop=allow_early_stop,
|
181 |
+
min_eos_p=min_eos_p,
|
182 |
+
)
|
183 |
+
|
184 |
+
# Check stopping conditions
|
185 |
+
# only stop if all generations in the batch reached the stopping condition
|
186 |
+
if torch.all(should_stop):
|
187 |
+
progress_bar.update(step - last_progress + 1)
|
188 |
+
break
|
189 |
+
|
190 |
+
if step == max_steps - 1:
|
191 |
+
progress_bar.update()
|
192 |
+
break
|
193 |
+
|
194 |
+
# Append the new token to the sequence
|
195 |
+
x = torch.cat((x, next_token), dim=1)
|
196 |
+
|
197 |
+
# Update duration and progress
|
198 |
+
# total_duration_s += duration_per_step
|
199 |
+
if step > last_progress:
|
200 |
+
progress_bar.update(step - last_progress)
|
201 |
+
last_progress = step
|
202 |
+
|
203 |
+
# Clean up tensors to manage memory
|
204 |
+
del logits, next_token
|
205 |
+
|
206 |
+
# Finalize progress bar
|
207 |
+
progress_bar.total = step + 1
|
208 |
+
progress_bar.close()
|
209 |
+
|
210 |
+
# Extract generated tokens (skip initial 513 context tokens)
|
211 |
+
output = x[:, 256 + 256 + 1 :].detach()
|
212 |
+
|
213 |
+
return output
|
214 |
+
|
215 |
+
|
216 |
+
def _sample_next_token(
|
217 |
+
logits: torch.Tensor,  # shape [batch, seq, vocab_size]
|
218 |
+
temperature: float,
|
219 |
+
top_k: Optional[int],
|
220 |
+
top_p: Optional[float],
|
221 |
+
semantic_eos_token: int,
|
222 |
+
allow_early_stop: bool,
|
223 |
+
min_eos_p: Optional[float],
|
224 |
+
) -> Tuple[torch.Tensor, torch.BoolTensor]:
|
225 |
+
"""
|
226 |
+
Sample the next token from logits with optional top-k, top-p filtering and early stopping.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
logits: Tensor of shape [batch, seq, vocab_size] containing model predictions.
|
230 |
+
temperature: Controls randomness of sampling (lower = more deterministic).
|
231 |
+
top_k: If set, keeps only the top-k logits.
|
232 |
+
top_p: If set, applies nucleus (top-p) filtering.
|
233 |
+
semantic_eos_token: Token ID representing end-of-sequence (equal to the semantic vocab size).
|
234 |
+
allow_early_stop: Whether to check for EOS token or probability threshold.
|
235 |
+
min_eos_p: Minimum probability for EOS to trigger early stop.
|
236 |
+
eos_token: Token ID representing end-of-sequence.
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
Tuple[next_token, should_stop]:
|
240 |
+
- next_token: Sampled token (shape [batch, 1]).
|
241 |
+
- should_stop: Whether to stop generation (EOS detected).
|
242 |
+
"""
|
243 |
+
# Extract logits for the last position in the sequence
|
244 |
+
relevant_logits = logits[:, -1, :semantic_eos_token]
|
245 |
+
|
246 |
+
# Append EOS logit if early stopping is allowed
|
247 |
+
if allow_early_stop:
|
248 |
+
eos_logit = logits[:, -1, [semantic_eos_token]]
|
249 |
+
relevant_logits = torch.hstack((relevant_logits, eos_logit))
|
250 |
+
|
251 |
+
# select the token with the highest probability
|
252 |
+
if temperature is None:
|
253 |
+
# next_token shape (B, 1)
|
254 |
+
probs = F.softmax(relevant_logits, dim=-1)
|
255 |
+
next_token = torch.argmax(probs, dim=-1, keepdim=True)
|
256 |
+
# when the model predicts token_id 206, it tends to keep predicting that same token_id under argmax
|
257 |
+
# we will intentionally avoid that token_id here
|
258 |
+
if torch.any(next_token == 206):
|
259 |
+
next_token = anything_but(probs, 206)
|
260 |
+
|
261 |
+
# do some maneuvers to introduce diversity in the sampling of the next token
|
262 |
+
else:
|
263 |
+
# Apply top-p (nucleus) filtering for diversity
|
264 |
+
if top_p is not None: # this if branch is untested
|
265 |
+
# Convert to NumPy for faster sorting (optimization from original)
|
266 |
+
original_device = relevant_logits.device
|
267 |
+
logits_np = relevant_logits.detach().cpu().type(torch.float32).numpy()
|
268 |
+
sorted_indices = np.argsort(logits_np)[::-1] # Descending order
|
269 |
+
sorted_logits = logits_np[sorted_indices]
|
270 |
+
cumulative_probs = np.cumsum(
|
271 |
+
F.softmax(torch.from_numpy(sorted_logits), dim=-1).numpy()
|
272 |
+
)
|
273 |
+
indices_to_remove = cumulative_probs > top_p
|
274 |
+
# Shift to keep at least one
|
275 |
+
indices_to_remove[1:] = indices_to_remove[:-1].copy()
|
276 |
+
indices_to_remove[0] = False # Ensure top token stays
|
277 |
+
logits_np[sorted_indices[indices_to_remove]] = -np.inf
|
278 |
+
relevant_logits = torch.from_numpy(logits_np).to(original_device)
|
279 |
+
|
280 |
+
# Apply top-k filtering for diversity
|
281 |
+
if top_k is not None:
|
282 |
+
top_values, _ = torch.topk(
|
283 |
+
relevant_logits, min(top_k, relevant_logits.size(-1))
|
284 |
+
)
|
285 |
+
# compare the whole logit tensor to its k_th largest value, batch wise
|
286 |
+
relevant_logits[relevant_logits < top_values[:, [-1]]] = -float("Inf")
|
287 |
+
|
288 |
+
# Compute probabilities with temperature scaling
|
289 |
+
probs = F.softmax(relevant_logits / temperature, dim=-1)
|
290 |
+
|
291 |
+
# Sample the next token
|
292 |
+
next_token = torch.multinomial(probs, num_samples=1).to(torch.int32)
|
293 |
+
|
294 |
+
# Check for early stopping conditions for each sequence in the batch
|
295 |
+
if allow_early_stop:
|
296 |
+
# EOS token is vocab_size when appended
|
297 |
+
is_eos_token = (next_token == semantic_eos_token).flatten()
|
298 |
+
eos_prob_high = min_eos_p is not None and probs[:, -1] >= min_eos_p
|
299 |
+
should_stop = torch.logical_or(is_eos_token, eos_prob_high)
|
300 |
+
|
301 |
+
# when batch dimension is 1, next_token is a 1D array, need to make it 2D
|
302 |
+
if len(next_token.shape) == 1:
|
303 |
+
next_token = next_token.unsqueeze(0)
|
304 |
+
return next_token, should_stop
|
305 |
+
|
306 |
+
|
307 |
+
# select the second largest probability token if the argmax is the avoided token
|
308 |
+
# otherwise select the argmax token
|
309 |
+
def anything_but(probs: torch.Tensor, avoid_id: int) -> torch.Tensor:
|
310 |
+
# probs shape (B, C)
|
311 |
+
# return tensor shape (B, 1)
|
312 |
+
values, indices = torch.topk(probs, 2, dim=-1)
|
313 |
+
selected = []
|
314 |
+
# loop over the batch dimension
|
315 |
+
for b in range(probs.shape[0]):
|
316 |
+
if indices[b, 0] == avoid_id:
|
317 |
+
selected.append(indices[b, 1])
|
318 |
+
continue
|
319 |
+
selected.append(indices[b, 0])
|
320 |
+
return torch.tensor(selected, dtype=torch.int32, device=probs.device).unsqueeze(1)
|
321 |
+
|
322 |
+
|
323 |
+
def validate_semantic_token_output(output: torch.Tensor) -> None:
|
324 |
+
assert torch.all(
|
325 |
+
(0 <= output) & (output <= SEMANTIC_VOCAB_SIZE)
|
326 |
+
), "unexpected output tokens"
|
327 |
+
|
328 |
+
|
329 |
+
# preprocess the texts for the generate_text_semantic model
|
330 |
+
def _preprocess_texts(texts: List[str]) -> List[str]:
|
331 |
+
return [re.sub(r"\s+", " ", text).strip() for text in texts]
|
332 |
+
|
333 |
+
|
334 |
+
def trim_or_pad_array(
|
335 |
+
array: Union[np.ndarray, torch.Tensor], pad_token: int, max_length: int = 256
|
336 |
+
) -> torch.Tensor:
|
337 |
+
"""
|
338 |
+
Trim on the left (keep the rightmost tokens), pad on the right.
|
339 |
+
"""
|
340 |
+
# Convert np.ndarray to torch.Tensor if necessary
|
341 |
+
if isinstance(array, np.ndarray):
|
342 |
+
tensor = torch.from_numpy(array).to(device=torch.device(env.DEVICE))
|
343 |
+
else: # Already a torch.Tensor
|
344 |
+
tensor = array
|
345 |
+
|
346 |
+
# Get the current length
|
347 |
+
current_length = tensor.shape[0]
|
348 |
+
|
349 |
+
if current_length > max_length:
|
350 |
+
# Trim from the end (last max_length elements)
|
351 |
+
return tensor[-max_length:]
|
352 |
+
|
353 |
+
elif current_length < max_length:
|
354 |
+
# Pad on the right with pad_token up to max_length
|
355 |
+
padding = (0, max_length - current_length)
|
356 |
+
return torch.nn.functional.pad(
|
357 |
+
tensor, padding, mode="constant", value=pad_token
|
358 |
+
)
|
359 |
+
|
360 |
+
# If length equals max_length, just return as is
|
361 |
+
return tensor
|
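For reference, a hedged sketch of the 513-token context that generate_semantic_tokens_from_text assembles (256 text ids, 256 semantic-prompt ids, one infer token); the token ids used here are placeholders rather than the repo's constants:

import torch
import torch.nn.functional as F

# Placeholder ids for the sketch; the real values live in core/bark/constants.py
TEXT_PAD_TOKEN, SEMANTIC_PAD_TOKEN, SEMANTIC_INFER_TOKEN = 129_595, 10_000, 129_599

text_ids = torch.arange(10)                                            # a short "sentence"
text_ids = F.pad(text_ids, (0, 256 - text_ids.shape[0]), value=TEXT_PAD_TOKEN)
prompt = torch.full((256,), SEMANTIC_PAD_TOKEN)                        # empty semantic history
ctx = torch.cat([text_ids, prompt, torch.tensor([SEMANTIC_INFER_TOKEN])])
assert ctx.shape[0] == 256 + 256 + 1                                   # same layout asserted above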
core/bark/voice_clone.py
ADDED
@@ -0,0 +1,104 @@
1 |
+
import torch
|
2 |
+
import torchaudio
import numpy as np
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
from core.utils import read_audio_file
|
6 |
+
from core.bark import encodec_encode_audio
|
7 |
+
|
8 |
+
from core.model.hubert import HuBERTForBarkSemantic
|
9 |
+
from core.memory import model_manager, ModelEnum
|
10 |
+
from core.bark.custom_context import InferenceContext
|
11 |
+
from core.data_model import *
|
12 |
+
|
13 |
+
|
14 |
+
HUBERT_SAMPLE_RATE = 16000
|
15 |
+
|
16 |
+
|
17 |
+
def generate_semantic_tokens_from_hubert(
|
18 |
+
waves: torch.Tensor,
|
19 |
+
audio_sample_rate: int,
|
20 |
+
temperature: float,
|
21 |
+
eos_p: float,
|
22 |
+
max_length: int,
|
23 |
+
device: Optional[torch.device],
|
24 |
+
inference_dtype: torch.dtype = torch.float32,
|
25 |
+
) -> torch.Tensor:
|
26 |
+
"""
|
27 |
+
Generate semantic tokens from audio using the HuBERT model.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
waves: 2D tensor of raw audio samples (shape: [B, T], where T is the number of samples)
|
31 |
+
audio_sample_rate: Sample rate of the input audio (resampled to 16 kHz for HuBERT when different)
|
32 |
+
temperature, eos_p, max_length: Sampling controls forwarded to the HuBERT model's generate method
|
33 |
+
device: Torch device to run the model on (defaults to CUDA if available, else CPU)
|
34 |
+
max_length: Maximum length of semantic tokens to return (optional, for truncation)
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
torch.Tensor: 1D tensor of semantic tokens (e.g., shape [N], where N is the sequence length)
|
38 |
+
|
39 |
+
Raises:
|
40 |
+
RuntimeError: If HuBERT model loading or processing fails
|
41 |
+
"""
|
42 |
+
assert (
|
43 |
+
len(waves.shape) == 2
|
44 |
+
), f"expecting a tensor of shape [B, T], got {waves.shape}"
|
45 |
+
waves = waves.to(device)
|
46 |
+
|
47 |
+
# HuBERT expects audio at 16 kHz; resample if necessary
|
48 |
+
if audio_sample_rate != HUBERT_SAMPLE_RATE:
|
49 |
+
resampler = torchaudio.transforms.Resample(
|
50 |
+
orig_freq=audio_sample_rate, new_freq=HUBERT_SAMPLE_RATE
|
51 |
+
).to(device)
|
52 |
+
waves = resampler(waves)
|
53 |
+
|
54 |
+
model = model_manager.get_model(ModelEnum.HuBERTBaseForBarkSemantic.value).model
|
55 |
+
|
56 |
+
assert isinstance(
|
57 |
+
model, HuBERTForBarkSemantic
|
58 |
+
), f"expecting HuBERTForBarkSemantic model type, received {type(model)}"
|
59 |
+
|
60 |
+
waves = waves.to(dtype=inference_dtype)
|
61 |
+
model = model.to(dtype=inference_dtype)
|
62 |
+
|
63 |
+
with InferenceContext():
|
64 |
+
predictions: torch.Tensor = model.generate(
|
65 |
+
wav_input=waves, temperature=temperature, eos_p=eos_p, max_length=max_length
|
66 |
+
)
|
67 |
+
|
68 |
+
return predictions
|
69 |
+
|
70 |
+
|
71 |
+
def create_bark_prompt(
|
72 |
+
audio_file: AudioFile, temperature: float, eos_p: float, device: torch.device
|
73 |
+
) -> BarkPrompt:
|
74 |
+
"""
|
75 |
+
Turn raw audio into valid BARK prompt. When given a raw audio file, use this function
|
76 |
+
to generate a valid BARK prompt
|
77 |
+
"""
|
78 |
+
# Read the audio
|
79 |
+
raw_audio = read_audio_file(
|
80 |
+
path=audio_file.audio_file_path,
|
81 |
+
target_sample_rate=HUBERT_SAMPLE_RATE,
|
82 |
+
channels=1,
|
83 |
+
max_duration=15,
|
84 |
+
)
|
85 |
+
|
86 |
+
audio_tensor = torch.tensor(raw_audio.astype(np.float32), device=device)
|
87 |
+
# Generate semantic tokens from audio using HuBERT
|
88 |
+
semantic_tokens: torch.Tensor = generate_semantic_tokens_from_hubert(
|
89 |
+
waves=audio_tensor.unsqueeze(0),
|
90 |
+
audio_sample_rate=16000,
|
91 |
+
temperature=temperature,
|
92 |
+
eos_p=eos_p,
|
93 |
+
max_length=600,
|
94 |
+
device=device,
|
95 |
+
)
|
96 |
+
|
97 |
+
# Generate codebook tokens using EnCodec
|
98 |
+
codes = encodec_encode_audio(
|
99 |
+
audio_sample=torch.from_numpy(raw_audio[None]),
|
100 |
+
audio_sample_rate=HUBERT_SAMPLE_RATE,
|
101 |
+
)
|
102 |
+
|
103 |
+
# Assuming codes has shape [num_codebooks, T], typically 8 codebooks for 24kHz
|
104 |
+
return BarkPrompt(semantic_tokens, codes[:2, :], codes[:, :])
|
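The last line of create_bark_prompt encodes the convention that the coarse history is just the first two EnCodec codebooks while the fine history keeps all eight; a tiny sketch with random codes standing in for real EnCodec output (the sizes are illustrative):

import torch

codes = torch.randint(0, 1024, (8, 750), dtype=torch.int32)   # (n_codebooks, T)
coarse_prompt, fine_prompt = codes[:2, :], codes
print(coarse_prompt.shape, fine_prompt.shape)                 # torch.Size([2, 750]) torch.Size([8, 750])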
core/data_model/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
from core.data_model.bark import *
|
core/data_model/bark.py
ADDED
@@ -0,0 +1,337 @@
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from pathlib import Path
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from dataclasses import dataclass, asdict, fields
|
7 |
+
import numpy as np
|
8 |
+
from enum import Enum
|
9 |
+
from pydantic import BaseModel, Field
|
10 |
+
from typing import Optional, Union, List, Literal
|
11 |
+
from datetime import datetime
|
12 |
+
from core.utils import save_audio_file, read_audio_file
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class BarkGenerationConfig:
|
17 |
+
semantic_top_k: Union[int, None] = 1000 # a tenth of the semantic vocab size
|
18 |
+
coarse_top_k: Union[int, None] = 100 # a tenth of the coarse codebook size
|
19 |
+
semantic_top_p: Union[int, None] = None
|
20 |
+
coarse_top_p: Union[int, None] = None
|
21 |
+
min_eos_p: float = 0.5
|
22 |
+
max_gen_duration_second: Union[float, None] = None
|
23 |
+
allow_early_stop: bool = True
|
24 |
+
use_kv_caching: bool = True
|
25 |
+
max_coarse_history: int = 630
|
26 |
+
sliding_window_length: int = 60
|
27 |
+
max_token_per_example: int = 256
|
28 |
+
# set to None to use argmax sampling
|
29 |
+
temperature: float = 0.6
|
30 |
+
generate_coarse_temperature: float = 0.6
|
31 |
+
# set this to None if you want to use argmax to generate fine token
|
32 |
+
generate_fine_temperature: float = 0.6
|
33 |
+
use_small_model: bool = True
|
34 |
+
|
35 |
+
def __init__(self, **kwargs):
|
36 |
+
# Get field names from dataclass
|
37 |
+
valid_fields = {f.name for f in fields(self)}
|
38 |
+
# Set only known fields
|
39 |
+
for key, value in kwargs.items():
|
40 |
+
if key in valid_fields:
|
41 |
+
setattr(self, key, value)
|
42 |
+
|
43 |
+
def to_dict(self) -> dict:
|
44 |
+
return asdict(self)
|
45 |
+
|
46 |
+
@classmethod
|
47 |
+
def from_dict(cls, data: dict) -> "BarkGenerationConfig":
|
48 |
+
return cls(**data)
|
49 |
+
|
50 |
+
|
51 |
+
@dataclass
|
52 |
+
class BarkPrompt:
|
53 |
+
"""
|
54 |
+
semantic_prompt shape: (T)
|
55 |
+
coarse_prompt shape: (2, T)
|
56 |
+
fine_prompt shape: (8, T)
|
57 |
+
these T values differ because each token type has a different rate of tokens per second
|
58 |
+
"""
|
59 |
+
|
60 |
+
semantic_prompt: torch.Tensor
|
61 |
+
coarse_prompt: torch.Tensor
|
62 |
+
fine_prompt: torch.Tensor
|
63 |
+
|
64 |
+
def save_prompt(self, file_path: str) -> bool:
|
65 |
+
"""
|
66 |
+
Save all 3 prompts to disk as JSON. Return True if success, False if error
|
67 |
+
"""
|
68 |
+
# Ensure the directory exists
|
69 |
+
directory = os.path.dirname(file_path)
|
70 |
+
if directory: # If there's a directory component
|
71 |
+
os.makedirs(directory, exist_ok=True)
|
72 |
+
|
73 |
+
data = {
|
74 |
+
"semantic_prompt": self.semantic_prompt.detach().cpu().tolist(),
|
75 |
+
"coarse_prompt": self.coarse_prompt.detach().cpu().tolist(),
|
76 |
+
"fine_prompt": self.fine_prompt.detach().cpu().tolist(),
|
77 |
+
}
|
78 |
+
|
79 |
+
if not file_path.endswith(".json"):
|
80 |
+
file_path += ".json"
|
81 |
+
|
82 |
+
try:
|
83 |
+
with open(file_path, "w", encoding="utf-8") as f:
|
84 |
+
json.dump(data, f)
|
85 |
+
return True
|
86 |
+
except Exception:
|
87 |
+
return False
|
88 |
+
|
89 |
+
@classmethod
|
90 |
+
def load_prompt(cls, file_path: str, device: torch.device) -> "BarkPrompt":
|
91 |
+
"""
|
92 |
+
Load a prompt from disk. File to load can be either a .json or .npz file
|
93 |
+
"""
|
94 |
+
try:
|
95 |
+
if file_path.endswith(".json"):
|
96 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
97 |
+
prompt = json.load(f)
|
98 |
+
|
99 |
+
assert (
|
100 |
+
"semantic_prompt" in prompt
|
101 |
+
and "coarse_prompt" in prompt
|
102 |
+
and "fine_prompt" in prompt
|
103 |
+
), f"invalid prompt data {prompt}"
|
104 |
+
|
105 |
+
semantic_prompt = torch.tensor(prompt["semantic_prompt"])
|
106 |
+
coarse_prompt = torch.tensor(prompt["coarse_prompt"])
|
107 |
+
fine_prompt = torch.tensor(prompt["fine_prompt"])
|
108 |
+
|
109 |
+
elif file_path.endswith(".npz"):
|
110 |
+
with np.load(file_path) as data:
|
111 |
+
assert (
|
112 |
+
"semantic_prompt" in data
|
113 |
+
and "coarse_prompt" in data
|
114 |
+
and "fine_prompt" in data
|
115 |
+
), f"invalid prompt data in NPZ file"
|
116 |
+
|
117 |
+
semantic_prompt = torch.from_numpy(data["semantic_prompt"])
|
118 |
+
coarse_prompt = torch.from_numpy(data["coarse_prompt"])
|
119 |
+
fine_prompt = torch.from_numpy(data["fine_prompt"])
|
120 |
+
|
121 |
+
else:
|
122 |
+
raise ValueError("Unsupported file format. Use .json or .npz")
|
123 |
+
|
124 |
+
# Convert to device and dtype after loading
|
125 |
+
semantic_prompt = semantic_prompt.to(device=device, dtype=torch.int32)
|
126 |
+
coarse_prompt = coarse_prompt.to(device=device, dtype=torch.int32)
|
127 |
+
fine_prompt = fine_prompt.to(device=device, dtype=torch.int32)
|
128 |
+
|
129 |
+
# Shape checks remain the same
|
130 |
+
if len(semantic_prompt.shape) == 2:
|
131 |
+
semantic_prompt = semantic_prompt[0, :]
|
132 |
+
assert (
|
133 |
+
len(semantic_prompt.shape) == 1
|
134 |
+
), "expecting semantic_prompt as a 1D array"
|
135 |
+
|
136 |
+
assert (
|
137 |
+
coarse_prompt.shape[0] == 2
|
138 |
+
), "expecting coarse_prompt has 2 code book dimension"
|
139 |
+
|
140 |
+
assert (
|
141 |
+
fine_prompt.shape[0] == 8
|
142 |
+
), "expecting fine_prompt has 8 code book dimension"
|
143 |
+
|
144 |
+
return cls(semantic_prompt, coarse_prompt, fine_prompt)
|
145 |
+
|
146 |
+
except Exception as e:
|
147 |
+
raise ValueError(f"Failed to load file: {str(e)}")
|
148 |
+
|
149 |
+
|
150 |
+
class AudioFile(BaseModel):
|
151 |
+
"""Model for validating raw audio prompt inputs."""
|
152 |
+
|
153 |
+
audio_file_path: str = Field(..., description="Path to the audio file")
|
154 |
+
max_duration: int = Field(
|
155 |
+
..., ge=1, description="Maximum duration of the audio in seconds"
|
156 |
+
)
|
157 |
+
|
158 |
+
def get_default_prompt_name(self) -> str:
|
159 |
+
audio_file_name = Path(self.audio_file_path).name
|
160 |
+
return f"{audio_file_name}_{datetime.now().strftime('%Y_%m_%d_%H_%M')}"
|
161 |
+
|
162 |
+
|
163 |
+
class TextToAudioInput(BaseModel):
|
164 |
+
"""Model for validating inputs to the text-to-audio generation function."""
|
165 |
+
|
166 |
+
texts: List[str] = Field(
|
167 |
+
..., min_items=1, description="List of text strings to convert to audio"
|
168 |
+
)
|
169 |
+
audio_prompt: Optional[Union[AudioFile, str]] = Field(
|
170 |
+
None, description="Optional audio prompt (raw or file path)"
|
171 |
+
)
|
172 |
+
sample_rate: int = Field(
|
173 |
+
default=24000, ge=1, description="Sample rate for generated audio"
|
174 |
+
)
|
175 |
+
device: Optional[str] = Field(
|
176 |
+
None, description="Device to use for generation (e.g., 'cuda', 'cpu')"
|
177 |
+
)
|
178 |
+
save_path: str = Field(
|
179 |
+
default="./artifact", description="Directory to save generated audio files"
|
180 |
+
)
|
181 |
+
|
182 |
+
|
183 |
+
class TextToAudioModel(Enum):
|
184 |
+
BARK = "BARK"
|
185 |
+
|
186 |
+
|
187 |
+
@dataclass
|
188 |
+
class WavSemantic:
|
189 |
+
"""
|
190 |
+
An example of a pair (wav, semantic) for training a model to predict semantic from audio
|
191 |
+
"""
|
192 |
+
|
193 |
+
text: str
|
194 |
+
wav: np.ndarray
|
195 |
+
semantic: np.ndarray
|
196 |
+
|
197 |
+
|
198 |
+
@dataclass
|
199 |
+
class WavSemanticDataset:
|
200 |
+
sample_rate: int
|
201 |
+
semantic_generation_config: BarkGenerationConfig
|
202 |
+
bark_model_type: Literal["small", "large"]
|
203 |
+
data: List[WavSemantic]
|
204 |
+
|
205 |
+
def save(self, save_path: str, save_raw_audio: bool) -> None:
|
206 |
+
"""
|
207 |
+
Save this WavSemanticDataset instance to disk at the specified path with compression.
|
208 |
+
|
209 |
+
Args:
|
210 |
+
save_path: Directory path where the dataset will be saved (default: './data').
|
211 |
+
"""
|
212 |
+
# Ensure the save directory exists
|
213 |
+
save_dir = Path(save_path)
|
214 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
215 |
+
|
216 |
+
# this allows incremental saving, e.g. saving each newly generated batch of data
|
217 |
+
if not os.path.exists(save_dir / "metadata.json"):
|
218 |
+
# Prepare metadata dictionary using instance attributes
|
219 |
+
metadata = {
|
220 |
+
"sample_rate": self.sample_rate,
|
221 |
+
"semantic_generation_config": self.semantic_generation_config.to_dict(),
|
222 |
+
"bark_model_type": self.bark_model_type,
|
223 |
+
}
|
224 |
+
|
225 |
+
# Save metadata as JSON
|
226 |
+
with open(save_dir / "metadata.json", "w") as f:
|
227 |
+
json.dump(metadata, f, indent=2)
|
228 |
+
|
229 |
+
next_index = self._get_latest_saved_file_index(save_path) + 1
|
230 |
+
# Save each WavSemantic sample
|
231 |
+
for i, sample in enumerate(self.data):
|
232 |
+
sample_dir = save_dir / f"sample_{i+next_index}"
|
233 |
+
sample_dir.mkdir(exist_ok=True)
|
234 |
+
|
235 |
+
# Save text
|
236 |
+
with open(sample_dir / "text.txt", "w") as f:
|
237 |
+
f.write(sample.text)
|
238 |
+
|
239 |
+
# Save wav and semantic in a single compressed .npz file
|
240 |
+
if save_raw_audio:
|
241 |
+
save_audio_file(
|
242 |
+
sample.wav, self.sample_rate, str(sample_dir / "audio.wav")
|
243 |
+
)
|
244 |
+
with open(sample_dir / "semantic.json", "w") as f:
|
245 |
+
json.dump(sample.semantic.tolist(), f)
|
246 |
+
else:
|
247 |
+
np.savez_compressed(
|
248 |
+
sample_dir / "data.npz", wav=sample.wav, semantic=sample.semantic
|
249 |
+
)
|
250 |
+
|
251 |
+
@staticmethod
|
252 |
+
def _get_latest_saved_file_index(dataset_path: str) -> int:
|
253 |
+
file_names = os.listdir(dataset_path)
|
254 |
+
file_names.remove("metadata.json")
|
255 |
+
if len(file_names) == 0:
|
256 |
+
return -1
|
257 |
+
|
258 |
+
indices = [
|
259 |
+
int(file_name.split("_")[-1].split(".")[0]) for file_name in file_names
|
260 |
+
]
|
261 |
+
|
262 |
+
return max(indices)
|
263 |
+
|
264 |
+
@classmethod
|
265 |
+
def load(cls, load_path: str, num_samples: int = 5000) -> "WavSemanticDataset":
|
266 |
+
"""
|
267 |
+
Load a WavSemanticDataset from disk at the specified path.
|
268 |
+
|
269 |
+
Args:
|
270 |
+
load_path: Directory path where the dataset is saved.
|
271 |
+
num_samples: maximum number of samples to load from the folder
|
272 |
+
Returns:
|
273 |
+
A new WavSemanticDataset instance loaded from disk.
|
274 |
+
"""
|
275 |
+
load_dir = Path(load_path)
|
276 |
+
if not load_dir.exists():
|
277 |
+
raise FileNotFoundError(f"Directory {load_path} does not exist")
|
278 |
+
|
279 |
+
filenames = os.listdir(load_dir)
|
280 |
+
if len(filenames) == 1:
|
281 |
+
# when there is a folder inside the load_path folder, step into it
|
282 |
+
load_dir = load_dir / filenames[0]
|
283 |
+
filenames = os.listdir(load_dir)
|
284 |
+
|
285 |
+
# Load metadata
|
286 |
+
with open(load_dir / "metadata.json", "r") as f:
|
287 |
+
metadata = json.load(f)
|
288 |
+
|
289 |
+
# Reconstruct semantic_generation_config
|
290 |
+
config = BarkGenerationConfig.from_dict(metadata["semantic_generation_config"])
|
291 |
+
|
292 |
+
# Load each WavSemantic sample
|
293 |
+
data = []
|
294 |
+
for i, filename in enumerate(filenames):
|
295 |
+
if not "sample" in filename:
|
296 |
+
continue
|
297 |
+
sample_dir = load_dir / filename
|
298 |
+
|
299 |
+
# Load text
|
300 |
+
with open(sample_dir / "text.txt", "r") as f:
|
301 |
+
text = f.read()
|
302 |
+
|
303 |
+
# Load compressed wav and semantic from .npz file
|
304 |
+
if os.path.isfile(sample_dir / "data.npz"):
|
305 |
+
with np.load(sample_dir / "data.npz") as npz_data:
|
306 |
+
wav = npz_data["wav"]
|
307 |
+
semantic = npz_data["semantic"]
|
308 |
+
# assuming audio wave file was stored separately from the semantic file
|
309 |
+
else:
|
310 |
+
# assuming "audio.wav" and "semantic.npz" exist in the folder
|
311 |
+
wav = read_audio_file(
|
312 |
+
sample_dir / "audio.wav", metadata["sample_rate"], 1, False, None
|
313 |
+
)
|
314 |
+
if os.path.isfile(sample_dir / "semantic.npz"):
|
315 |
+
with np.load(sample_dir / "semantic.npz") as npz_data:
|
316 |
+
semantic = npz_data["semantic"]
|
317 |
+
elif os.path.isfile(sample_dir / "semantic.json"):
|
318 |
+
with open(sample_dir / "semantic.json") as f:
|
319 |
+
semantic = np.array(json.load(f))
|
320 |
+
|
321 |
+
data.append(WavSemantic(text=text, wav=wav, semantic=semantic))
|
322 |
+
if i > num_samples:
|
323 |
+
break
|
324 |
+
|
325 |
+
# Reconstruct and return the dataset
|
326 |
+
return cls(
|
327 |
+
sample_rate=metadata["sample_rate"],
|
328 |
+
semantic_generation_config=config,
|
329 |
+
bark_model_type=metadata["bark_model_type"],
|
330 |
+
data=data,
|
331 |
+
)
|
332 |
+
|
333 |
+
def __getitem__(self, idx: int) -> WavSemantic:
|
334 |
+
return self.data[idx]
|
335 |
+
|
336 |
+
def __len__(self) -> int:
|
337 |
+
return len(self.data)
|
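A hedged round-trip sketch of the BarkPrompt save/load helpers defined above, assuming the package is importable from the repo root; the tensor sizes and output path are illustrative only:

import torch
from core.data_model.bark import BarkPrompt

semantic = torch.randint(0, 10_000, (300,), dtype=torch.int32)   # (T,)
coarse = torch.randint(0, 1024, (2, 450), dtype=torch.int32)     # (2, T')
fine = torch.randint(0, 1024, (8, 450), dtype=torch.int32)       # (8, T')

prompt = BarkPrompt(semantic, coarse, fine)
assert prompt.save_prompt("./artifact/demo_prompt.json")
restored = BarkPrompt.load_prompt("./artifact/demo_prompt.json", torch.device("cpu"))
print(restored.semantic_prompt.shape, restored.coarse_prompt.shape, restored.fine_prompt.shape)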
core/memory/__init__.py
ADDED
@@ -0,0 +1,5 @@
1 |
+
from core.memory.model_manager import *
|
2 |
+
|
3 |
+
from core.memory.model_manager import *
|
4 |
+
|
5 |
+
from core.memory.common import *
|
core/memory/common.py
ADDED
@@ -0,0 +1,187 @@
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
from enum import Enum
|
5 |
+
from typing import ClassVar, Dict, Any
|
6 |
+
import torch
|
7 |
+
from huggingface_hub import hf_hub_download
|
8 |
+
|
9 |
+
# Configure logging with a default level (will be updated by EnvVars)
|
10 |
+
logging.basicConfig(level=logging.INFO)
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
class LogLevel(Enum):
|
15 |
+
"""Enumeration of valid logging levels."""
|
16 |
+
|
17 |
+
DEBUG = "DEBUG"
|
18 |
+
INFO = "INFO"
|
19 |
+
WARNING = "WARNING"
|
20 |
+
ERROR = "ERROR"
|
21 |
+
CRITICAL = "CRITICAL"
|
22 |
+
|
23 |
+
|
24 |
+
def grab_best_device(use_gpu: bool, enable_mps: bool) -> str:
|
25 |
+
"""
|
26 |
+
Determine the best available device for PyTorch operations.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
use_gpu (bool): Whether to prioritize GPU/MPS over CPU.
|
30 |
+
enable_mps (bool): Whether to allow MPS (Metal Performance Shaders) on Apple Silicon.
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
str: Device identifier ("cuda", "mps", or "cpu").
|
34 |
+
"""
|
35 |
+
if use_gpu and torch.cuda.is_available():
|
36 |
+
device = "cuda"
|
37 |
+
logger.debug("Selected CUDA device (GPU available)")
|
38 |
+
elif use_gpu and enable_mps and torch.backends.mps.is_available():
|
39 |
+
device = "mps"
|
40 |
+
logger.debug("Selected MPS device (Apple Silicon GPU available)")
|
41 |
+
else:
|
42 |
+
device = "cpu"
|
43 |
+
logger.debug("Selected CPU device (no GPU/MPS available or disabled)")
|
44 |
+
return device
|
45 |
+
|
46 |
+
|
47 |
+
class EnvVars:
|
48 |
+
"""
|
49 |
+
Class to manage and expose environment variables with type safety and runtime configurability.
|
50 |
+
|
51 |
+
Loads variables from a .env file or system environment, applies defaults if not found, and allows updates
|
52 |
+
at runtime. Variables are stored as instance attributes rather than polluting the global namespace.
|
53 |
+
"""
|
54 |
+
|
55 |
+
# Default values for environment variables
|
56 |
+
_DEFAULTS: ClassVar[Dict[str, Any]] = {
|
57 |
+
"GLOBAL_ENABLE_MPS": True, # Enable PyTorch's Metal Performance Shaders on Apple Silicon
|
58 |
+
"AUDIO_SAMPLE_RATE": 24000, # Default sample rate for audio processing (in Hz)
|
59 |
+
"SUNO_USE_SMALL_MODELS": True, # Use smaller Bark models if True
|
60 |
+
"CACHE_DIR": "./models",
|
61 |
+
"LOG_LEVEL": LogLevel.INFO, # Default logging level
|
62 |
+
"USE_GPU": True, # Whether to prioritize GPU/MPS over CPU
|
63 |
+
}
|
64 |
+
|
65 |
+
def __init__(self) -> None:
|
66 |
+
"""Initialize the EnvVars instance and load variables."""
|
67 |
+
self._vars: Dict[str, Any] = {}
|
68 |
+
self._load_env_vars()
|
69 |
+
self._update_attributes()
|
70 |
+
|
71 |
+
def _load_env_vars(self) -> None:
|
72 |
+
"""Load environment variables from .env file or system, falling back to defaults."""
|
73 |
+
load_dotenv() # Load .env file into os.environ
|
74 |
+
for var_name, default_value in self._DEFAULTS.items():
|
75 |
+
value = os.getenv(var_name)
|
76 |
+
if value is None:
|
77 |
+
logger.info(
|
78 |
+
f"{var_name} not found in environment, using default: {default_value}"
|
79 |
+
)
|
80 |
+
self._vars[var_name] = default_value
|
81 |
+
else:
|
82 |
+
# Convert value to the appropriate type based on default
|
83 |
+
if isinstance(default_value, bool):
|
84 |
+
self._vars[var_name] = value.lower() in ("true", "1", "t")
|
85 |
+
elif isinstance(default_value, int):
|
86 |
+
self._vars[var_name] = int(value)
|
87 |
+
elif isinstance(default_value, float):
|
88 |
+
self._vars[var_name] = float(value)
|
89 |
+
elif isinstance(default_value, LogLevel):
|
90 |
+
self._vars[var_name] = LogLevel(value.upper())
|
91 |
+
else:
|
92 |
+
self._vars[var_name] = value
|
93 |
+
logger.info(
|
94 |
+
f"{var_name} loaded from environment: {self._vars[var_name]}"
|
95 |
+
)
|
96 |
+
|
97 |
+
def _update_attributes(self) -> None:
|
98 |
+
"""Update instance attributes and apply settings (e.g., logging level, device)."""
|
99 |
+
# Set instance attributes
|
100 |
+
self.GLOBAL_ENABLE_MPS: bool = self._vars["GLOBAL_ENABLE_MPS"]
|
101 |
+
self.AUDIO_SAMPLE_RATE: int = self._vars["AUDIO_SAMPLE_RATE"]
|
102 |
+
self.SUNO_USE_SMALL_MODELS: bool = self._vars["SUNO_USE_SMALL_MODELS"]
|
103 |
+
self.CACHE_DIR: str = self._vars["CACHE_DIR"]
|
104 |
+
self.LOG_LEVEL: LogLevel = self._vars["LOG_LEVEL"]
|
105 |
+
self.USE_GPU: bool = self._vars["USE_GPU"]
|
106 |
+
self.DEVICE: str = grab_best_device(self.USE_GPU, self.GLOBAL_ENABLE_MPS)
|
107 |
+
logging.getLogger().setLevel(self.LOG_LEVEL.value)
|
108 |
+
|
109 |
+
def update(self, var_name: str, value: Any) -> None:
|
110 |
+
"""
|
111 |
+
Update an environment variable at runtime and reapply settings.
|
112 |
+
|
113 |
+
Args:
|
114 |
+
var_name (str): Name of the variable to update (must be in _DEFAULTS).
|
115 |
+
value (Any): New value for the variable.
|
116 |
+
|
117 |
+
Raises:
|
118 |
+
KeyError: If var_name is not a recognized environment variable.
|
119 |
+
"""
|
120 |
+
if var_name not in self._DEFAULTS:
|
121 |
+
raise KeyError(f"Unknown environment variable: {var_name}")
|
122 |
+
|
123 |
+
# Convert value to the appropriate type based on default
|
124 |
+
default_type = type(self._DEFAULTS[var_name])
|
125 |
+
if default_type is bool:
|
126 |
+
self._vars[var_name] = bool(
|
127 |
+
value.lower() in ("true", "1", "t") if isinstance(value, str) else value
|
128 |
+
)
|
129 |
+
elif default_type is int:
|
130 |
+
self._vars[var_name] = int(value)
|
131 |
+
elif default_type is float:
|
132 |
+
self._vars[var_name] = float(value)
|
133 |
+
elif default_type is LogLevel:
|
134 |
+
self._vars[var_name] = LogLevel(
|
135 |
+
value.upper() if isinstance(value, str) else value
|
136 |
+
)
|
137 |
+
else:
|
138 |
+
self._vars[var_name] = value
|
139 |
+
|
140 |
+
logger.info(f"Updated {var_name} to {self._vars[var_name]}")
|
141 |
+
self._update_attributes()
|
142 |
+
|
143 |
+
|
144 |
+
# Create global instance to access environment variables
|
145 |
+
env = EnvVars()
|
146 |
+
|
147 |
+
|
148 |
+
def get_cached_or_download_model_from_hf(
|
149 |
+
repo_id: str, file_name: str, cache_dir: str = env.CACHE_DIR
|
150 |
+
) -> str:
|
151 |
+
"""
|
152 |
+
Download a model from Hugging Face Hub if not already cached.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
repo_id (str): The repository ID on Hugging Face Hub (e.g., 'suno/bark').
|
156 |
+
file_name (str): The name of the model file to download (e.g., 'text.pt').
|
157 |
+
cache_dir (str): Directory to store cached models (defaults to env.CACHE_DIR).
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
str: The full path to the downloaded or cached model file.
|
161 |
+
|
162 |
+
Raises:
|
163 |
+
OSError: If the cache directory cannot be created.
|
164 |
+
RuntimeError: If the download from Hugging Face fails.
|
165 |
+
"""
|
166 |
+
# Ensure cache directory exists
|
167 |
+
try:
|
168 |
+
os.makedirs(cache_dir, exist_ok=True)
|
169 |
+
except OSError as e:
|
170 |
+
logger.error(f"Failed to create cache directory {cache_dir}: {str(e)}")
|
171 |
+
raise
|
172 |
+
|
173 |
+
# Check if file is already cached
|
174 |
+
cached_path = os.path.join(cache_dir, file_name)
|
175 |
+
if os.path.exists(cached_path):
|
176 |
+
logger.debug(f"Model found in cache: {cached_path}")
|
177 |
+
return cached_path
|
178 |
+
|
179 |
+
# Download from Hugging Face if not cached
|
180 |
+
logger.info(f"Downloading model {repo_id}/{file_name} to {cache_dir}")
|
181 |
+
try:
|
182 |
+
hf_hub_download(repo_id=repo_id, filename=file_name, local_dir=cache_dir)
|
183 |
+
logger.debug(f"Model downloaded successfully to {cached_path}")
|
184 |
+
return cached_path
|
185 |
+
except Exception as e:
|
186 |
+
logger.error(f"Failed to download model {repo_id}/{file_name}: {str(e)}")
|
187 |
+
raise RuntimeError(f"Failed to download model {repo_id}/{file_name}: {str(e)}")
|
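A hedged usage sketch of the caching helper above; network access is assumed, and the repo id and file name simply mirror the examples given in its docstring:

from core.memory.common import get_cached_or_download_model_from_hf

# First call downloads into ./models; later calls return the cached path.
path = get_cached_or_download_model_from_hf(
    repo_id="suno/bark", file_name="text.pt", cache_dir="./models"
)
print(path)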
core/memory/model_manager.py
ADDED
@@ -0,0 +1,289 @@
1 |
+
import psutil
|
2 |
+
import logging
|
3 |
+
from typing import Dict, Optional, Callable, Any, Literal
|
4 |
+
from collections import OrderedDict
|
5 |
+
from threading import Lock
|
6 |
+
import torch
|
7 |
+
from transformers import BertTokenizer
|
8 |
+
from encodec import EncodecModel
|
9 |
+
|
10 |
+
from core.memory.common import get_cached_or_download_model_from_hf, env
|
11 |
+
from core.model.bark import GPTConfig, FineGPTConfig, GPT, FineGPT
|
12 |
+
from core.memory.models import *
|
13 |
+
|
14 |
+
# Configure logging for this module
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
def clear_cuda_cache() -> None:
|
19 |
+
"""
|
20 |
+
Clear the CUDA memory cache if GPU is available.
|
21 |
+
|
22 |
+
Raises:
|
23 |
+
RuntimeError: If CUDA operations fail unexpectedly.
|
24 |
+
"""
|
25 |
+
if torch.cuda.is_available():
|
26 |
+
try:
|
27 |
+
torch.cuda.empty_cache()
|
28 |
+
torch.cuda.synchronize()
|
29 |
+
logger.debug("CUDA cache cleared successfully")
|
30 |
+
except RuntimeError as e:
|
31 |
+
logger.error(f"Failed to clear CUDA cache: {str(e)}")
|
32 |
+
raise RuntimeError(f"CUDA cache clear failed: {str(e)}")
|
33 |
+
|
34 |
+
|
35 |
+
class ModelManager:
|
36 |
+
"""
|
37 |
+
Manager class for loading, caching, and unloading PyTorch models with memory management.
|
38 |
+
|
39 |
+
Prioritizes GPU memory when available, with an optional `offload_to_cpu` flag to use CPU RAM instead.
|
40 |
+
Uses an LRU (Least Recently Used) cache to keep only the most recently used models in memory.
|
41 |
+
Automatically unloads models when memory usage (GPU or CPU, depending on config) exceeds a threshold
|
42 |
+
or the maximum number of cached models is reached.
|
43 |
+
"""
|
44 |
+
|
45 |
+
def __init__(self, max_models: int = 10, offload_to_cpu: bool = False):
|
46 |
+
"""
|
47 |
+
Initialize the model manager.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
max_models (int): Maximum number of models to keep in memory before unloading (default: 10)
|
51 |
+
offload_to_cpu (bool): If True, use CPU RAM instead of GPU memory (default: False)
|
52 |
+
"""
|
53 |
+
self._models: OrderedDict = OrderedDict() # LRU cache for loaded models
|
54 |
+
self._lock = Lock() # Thread lock for safe concurrent access
|
55 |
+
self._max_models = max_models # Max number of models to cache
|
56 |
+
# Whether to offload models to CPU instead of GPU
|
57 |
+
self._offload_to_cpu = offload_to_cpu
|
58 |
+
self._device = torch.device(env.DEVICE) # Device to load models onto
|
59 |
+
logger.info(f"Model manager initialized with device: {self._device}")
|
60 |
+
|
61 |
+
def _check_memory(self) -> bool:
|
62 |
+
"""
|
63 |
+
Check if current memory usage is below the threshold, focusing on GPU unless offloaded to CPU.
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
bool: True if memory usage is safe, False if it exceeds the threshold
|
67 |
+
"""
|
68 |
+
if self._offload_to_cpu or not torch.cuda.is_available():
|
69 |
+
# Check CPU memory usage
|
70 |
+
mem = psutil.virtual_memory() # System memory stats
|
71 |
+
total_mem_used = mem.used / 1e9 # CPU memory used in GB
|
72 |
+
total_mem_available = mem.total / 1e9 # Total CPU memory in GB
|
73 |
+
else:
|
74 |
+
# Check GPU memory usage
|
75 |
+
total_mem_used = (
|
76 |
+
torch.cuda.memory_allocated() / 1e9
|
77 |
+
) # GPU memory used in GB
|
78 |
+
total_mem_available = (
|
79 |
+
torch.cuda.get_device_properties(0).total_memory / 1e9
|
80 |
+
) # Total GPU memory in GB
|
81 |
+
|
82 |
+
usage_ratio = total_mem_used / total_mem_available
|
83 |
+
logger.debug(
|
84 |
+
f"Memory usage on {self._device}: {usage_ratio:.2%} (threshold: {MEMORY_THRESHOLD})"
|
85 |
+
)
|
86 |
+
return usage_ratio < MEMORY_THRESHOLD
|
87 |
+
|
88 |
+
def _unload_lru_model(self):
|
89 |
+
"""Unload the least recently used model to free memory."""
|
90 |
+
with self._lock:
|
91 |
+
if self._models:
|
92 |
+
# Remove oldest entry
|
93 |
+
model_info, model_instance = self._models.popitem(last=False)
|
94 |
+
logger.info(
|
95 |
+
f"Unloading model {model_info} from {self._device} to free memory"
|
96 |
+
)
|
97 |
+
# Move model to CPU before deletion to ensure GPU memory is freed
|
98 |
+
if not self._offload_to_cpu and torch.cuda.is_available():
|
99 |
+
model_instance.model = model_instance.model.cpu()
|
100 |
+
del model_instance # Explicitly delete reference
|
101 |
+
logger.debug(f"Memory freed from {self._device}")
|
102 |
+
|
103 |
+
def get_model(self, model_info: ModelInfo) -> Model:
|
104 |
+
"""
|
105 |
+
Retrieve or load a model, managing memory constraints on the chosen device (GPU or CPU).
|
106 |
+
|
107 |
+
Args:
|
108 |
+
model_info (ModelInfo): Metadata for the model to load
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
Model: The loaded model instance with config and preprocessor
|
112 |
+
|
113 |
+
Raises:
|
114 |
+
ValueError: If model_info is invalid
|
115 |
+
"""
|
116 |
+
assert isinstance(
|
117 |
+
model_info, ModelInfo
|
118 |
+
), f"invalid model_info type {type(model_info)}"
|
119 |
+
with self._lock:
|
120 |
+
# If model is already loaded, move it to the end (most recently used) and return it
|
121 |
+
if model_info in self._models:
|
122 |
+
self._models.move_to_end(model_info)
|
123 |
+
return self._models[model_info]
|
124 |
+
|
125 |
+
# Ensure memory is available by unloading models if necessary
|
126 |
+
while not self._check_memory() or len(self._models) >= self._max_models:
|
127 |
+
self._unload_lru_model()
|
128 |
+
|
129 |
+
if model_info.load_model is not None:
|
130 |
+
model = model_info.load_model(model_info, torch.device(env.DEVICE))
|
131 |
+
elif model_info.checkpoint_name is not None:
|
132 |
+
model = load_transformers_model(model_info, self._device)
|
133 |
+
elif model_info.repo_id is not None and model_info.file_name is not None:
|
134 |
+
model_file_path = get_cached_or_download_model_from_hf(
|
135 |
+
repo_id=model_info.repo_id, file_name=model_info.file_name
|
136 |
+
)
|
137 |
+
model = load_model_from_file(model_info, model_file_path, self._device)
|
138 |
+
else:
|
139 |
+
raise ValueError(
|
140 |
+
"Invalid model info: must provide checkpoint_name or repo_id/file_name"
|
141 |
+
)
|
142 |
+
|
143 |
+
# Cache the loaded model
|
144 |
+
self._models[model_info] = model
|
145 |
+
clear_cuda_cache()
|
146 |
+
logger.info(f"Loaded and cached model {model_info} on {self._device}")
|
147 |
+
return model
|
148 |
+
|
149 |
+
def unload_model(self, model_info: ModelInfo):
|
150 |
+
"""
|
151 |
+
Manually unload a specific model from memory.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
model_info (ModelInfo): Metadata of the model to unload
|
155 |
+
"""
|
156 |
+
with self._lock:
|
157 |
+
if model_info in self._models:
|
158 |
+
model_instance = self._models[model_info]
|
159 |
+
# Move model to CPU before deletion if on GPU
|
160 |
+
if not self._offload_to_cpu and torch.cuda.is_available():
|
161 |
+
model_instance.model = model_instance.model.cpu()
|
162 |
+
del self._models[model_info]
|
163 |
+
logger.info(f"Manually unloaded model {model_info} from {self._device}")
|
164 |
+
|
165 |
+
|
166 |
+
def load_model_from_file(
|
167 |
+
model_info: ModelInfo, model_file_path: str, device: torch.device
|
168 |
+
) -> Model:
|
169 |
+
"""
|
170 |
+
Load a model from a file (e.g., custom weights from Hugging Face).
|
171 |
+
|
172 |
+
Args:
|
173 |
+
model_info (ModelInfo): Metadata for the model
|
174 |
+
model_file_path (str): Path to the model weights file
|
175 |
+
device (torch.device): Device to load the model onto (CPU or GPU)
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
Model: Loaded model instance
|
179 |
+
"""
|
180 |
+
if model_info.repo_id == "suno/bark":
|
181 |
+
return load_bark_model(model_info, model_file_path, device)
|
182 |
+
if model_info.model_type == "custom_hubert_tokenizer":
|
183 |
+
return load_custom_hubert_tokenizer(model_info, model_file_path, device)
|
184 |
+
raise ValueError(f"Unknown how to load model {model_info}")
|
185 |
+
|
186 |
+
|
187 |
+
# temporarily disabled: the custom HuBERT tokenizer is not loaded for now
|
188 |
+
def load_custom_hubert_tokenizer(
|
189 |
+
model_info: ModelInfo, model_file_path: str, device: torch.device
|
190 |
+
) -> Model:
|
191 |
+
# Automatically uses the right layers
|
192 |
+
# tokenizer = HuBERTForBarkSemantic.load_from_checkpoint(
|
193 |
+
# model_file_path, torch.device(env.DEVICE)
|
194 |
+
# ).to(device)
|
195 |
+
|
196 |
+
# return Model(model=tokenizer)
|
197 |
+
return Model(model=None)
|
198 |
+
|
199 |
+
|
200 |
+
def load_transformers_model(model_info: ModelInfo, device: torch.device) -> Model:
|
201 |
+
"""
|
202 |
+
Load a model using Hugging Face's transformers library.
|
203 |
+
|
204 |
+
Args:
|
205 |
+
model_info (ModelInfo): Metadata for the model
|
206 |
+
device (torch.device): Device to load the model onto (CPU or GPU)
|
207 |
+
|
208 |
+
Returns:
|
209 |
+
Model: Loaded model instance
|
210 |
+
"""
|
211 |
+
if model_info.checkpoint_name == "facebook/encodec_24khz":
|
212 |
+
model = EncodecModel.encodec_model_24khz()
|
213 |
+
model.set_target_bandwidth(6.0)  # a bare encode() call would fail here; configure the bandwidth as in _load_encodec_model
|
214 |
+
model = model.to(device)
|
215 |
+
return Model(model)
|
216 |
+
raise NotImplementedError("Only Encodec 24k supported for now")
|
217 |
+
|
218 |
+
|
219 |
+
def load_bark_model(
|
220 |
+
model_info: ModelInfo, model_file_path: str, device: torch.device
|
221 |
+
) -> Model:
|
222 |
+
"""
|
223 |
+
Load a Bark model from a file.
|
224 |
+
|
225 |
+
Args:
|
226 |
+
model_info (ModelInfo): Metadata for the Bark model
|
227 |
+
model_file_path (str): Path to the model weights file
|
228 |
+
device (torch.device): Device to load the model onto (CPU or GPU)
|
229 |
+
|
230 |
+
Returns:
|
231 |
+
Model: Loaded Bark model instance with config and optional tokenizer
|
232 |
+
"""
|
233 |
+
# Load checkpoint directly to the specified device
|
234 |
+
# weights_only=False is acceptable only because the checkpoint comes from a trusted source
|
235 |
+
checkpoint = torch.load(model_file_path, map_location=device, weights_only=False)
|
236 |
+
ConfigClass, ModelClass = (
|
237 |
+
(GPTConfig, GPT)
|
238 |
+
if model_info.model_type in ["text", "coarse"]
|
239 |
+
else (FineGPTConfig, FineGPT)
|
240 |
+
)
|
241 |
+
|
242 |
+
model_args = preprocess_model_args(checkpoint["model_args"])
|
243 |
+
|
244 |
+
conf = ConfigClass(**model_args)
|
245 |
+
model = ModelClass(conf)
|
246 |
+
state_dict = _update_bark_state_dict(model, checkpoint["model"])
|
247 |
+
model.load_state_dict(state_dict, strict=False)
|
248 |
+
|
249 |
+
model = model.to(device) # Ensure model is on the correct device
|
250 |
+
model.eval()
|
251 |
+
logger.info(f"Loaded Bark model: {model_info} on {device}")
|
252 |
+
|
253 |
+
# Add tokenizer for text models (tokenizer stays on CPU as it doesn't require GPU)
|
254 |
+
preprocessor = (
|
255 |
+
BertTokenizer.from_pretrained("bert-base-multilingual-cased")
|
256 |
+
if model_info.model_type == "text"
|
257 |
+
else None
|
258 |
+
)
|
259 |
+
return Model(model, conf, preprocessor)
|
260 |
+
|
261 |
+
|
262 |
+
def preprocess_model_args(model_args: dict) -> dict:
|
263 |
+
if "input_vocab_size" not in model_args:
|
264 |
+
model_args["input_vocab_size"] = model_args["vocab_size"]
|
265 |
+
model_args["output_vocab_size"] = model_args["vocab_size"]
|
266 |
+
del model_args["vocab_size"]
|
267 |
+
return model_args
|
268 |
+
|
269 |
+
|
270 |
+
def _update_bark_state_dict(model: GPT, state_dict: Dict[str, Any]) -> Dict[str, Any]:
|
271 |
+
"""
|
272 |
+
Update the state dictionary by removing unwanted prefixes (specific to Bark models).
|
273 |
+
|
274 |
+
Args:
|
275 |
+
model (GPT): The model instance to align the state dict with
|
276 |
+
state_dict (Dict[str, Any]): The loaded state dictionary
|
277 |
+
|
278 |
+
Returns:
|
279 |
+
Dict[str, Any]: Updated state dictionary
|
280 |
+
"""
|
281 |
+
unwanted_prefix = "_orig_mod."
|
282 |
+
for key in list(state_dict.keys()):
|
283 |
+
if key.startswith(unwanted_prefix):
|
284 |
+
state_dict[key[len(unwanted_prefix) :]] = state_dict.pop(key)
|
285 |
+
return state_dict
|
286 |
+
|
287 |
+
|
288 |
+
# Instantiate the global model manager with default GPU priority
|
289 |
+
model_manager = ModelManager(offload_to_cpu=False if env.USE_GPU else True)
|
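A short usage sketch for the manager (illustrative only; it downloads the Bark text checkpoint on first use and returns the LRU-cached instance afterwards):

from core.memory.model_manager import model_manager
from core.memory.models import ModelEnum

bark_text = model_manager.get_model(ModelEnum.BARK_TEXT.value)
tokenizer, gpt = bark_text.preprocessor, bark_text.model  # BertTokenizer + GPT for the "text" stage

# Unload explicitly once the model is no longer needed.
model_manager.unload_model(ModelEnum.BARK_TEXT.value)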
core/memory/models.py
ADDED
@@ -0,0 +1,169 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import logging
|
4 |
+
from dataclasses import asdict
|
5 |
+
from typing_extensions import Optional, Callable
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from enum import Enum
|
8 |
+
from transformers import BertTokenizer
|
9 |
+
from encodec import EncodecModel
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from core.model.bark import GPT
|
13 |
+
from core.memory.common import env
|
14 |
+
from core.utils import download_file_from_hf
|
15 |
+
from core.model.hubert import HuBERTForBarkSemantic, HubertForBarkSemanticConfig
|
16 |
+
|
17 |
+
logging.basicConfig(
|
18 |
+
level=logging.INFO,
|
19 |
+
format="%(asctime)s - %(levelname)s - %(message)s",
|
20 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
21 |
+
)
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
# Memory usage threshold (as a fraction of total memory) that triggers unloading of models
|
25 |
+
# 90% of available memory; applies to GPU unless offloaded to CPU
|
26 |
+
MEMORY_THRESHOLD = 0.9
|
27 |
+
|
28 |
+
|
29 |
+
@dataclass(frozen=True)
|
30 |
+
class ModelInfo:
|
31 |
+
"""Data structure to hold metadata about a model."""
|
32 |
+
|
33 |
+
# Hugging Face repository ID (e.g., "suno/bark")
|
34 |
+
repo_id: Optional[str] = None
|
35 |
+
# Filename of the model weights (e.g., "text.pt")
|
36 |
+
file_name: Optional[str] = None
|
37 |
+
# Pretrained checkpoint name (e.g., "facebook/encodec_24khz")
|
38 |
+
checkpoint_name: Optional[str] = None
|
39 |
+
# Configuration class for the model
|
40 |
+
config_class: Optional[type] = None
|
41 |
+
# Model class to instantiate
|
42 |
+
model_class: Optional[type] = None
|
43 |
+
# Preprocessor class (e.g., tokenizer)
|
44 |
+
preprocessor_class: Optional[type] = None
|
45 |
+
# Type of model (e.g., "text", "coarse", "encodec")
|
46 |
+
model_type: Optional[str] = None
|
47 |
+
# define the function that load the model
|
48 |
+
load_model: Optional[Callable] = None
|
49 |
+
|
50 |
+
|
51 |
+
@dataclass
|
52 |
+
class Model:
|
53 |
+
"""Container for a loaded model, its configuration, and preprocessor."""
|
54 |
+
|
55 |
+
model: Callable # The PyTorch model instance
|
56 |
+
config: Optional[Callable] = None # Model configuration object
|
57 |
+
# Preprocessor (e.g., tokenizer for text models)
|
58 |
+
preprocessor: Optional[Callable] = None
|
59 |
+
|
60 |
+
|
61 |
+
def _load_encodec_model(model_info: ModelInfo, device: torch.device) -> Model:
|
62 |
+
model = EncodecModel.encodec_model_24khz()
|
63 |
+
model.set_target_bandwidth(6.0)
|
64 |
+
model.eval()
|
65 |
+
model.to(device)
|
66 |
+
return Model(model)
|
67 |
+
|
68 |
+
|
69 |
+
def _load_hubert_base_for_bark_semantic(
|
70 |
+
model_info: ModelInfo, device: torch.device
|
71 |
+
) -> "Model":
|
72 |
+
os.makedirs(env.CACHE_DIR, exist_ok=True)
|
73 |
+
local_file_path = os.path.join(env.CACHE_DIR, model_info.file_name)
|
74 |
+
if not os.path.isfile(local_file_path):
|
75 |
+
logger.info(
|
76 |
+
f"Downloading {model_info.file_name} model from {model_info.repo_id}"
|
77 |
+
)
|
78 |
+
download_file_from_hf(
|
79 |
+
model_info.repo_id, "model", model_info.file_name, env.CACHE_DIR
|
80 |
+
)
|
81 |
+
|
82 |
+
checkpoint = torch.load(local_file_path, map_location=device)
|
83 |
+
|
84 |
+
assert isinstance(
|
85 |
+
checkpoint, dict
|
86 |
+
), "expecting a dictionary, got {type(checkpoint)}"
|
87 |
+
|
88 |
+
state_dict = checkpoint.get("model_state_dict", None)
|
89 |
+
assert (
|
90 |
+
state_dict is not None
|
91 |
+
), f"model_state_dict not in checkpoint, {checkpoint.keys()}"
|
92 |
+
|
93 |
+
model_config = checkpoint.get("config", None)
|
94 |
+
assert model_config is not None, "not found model config in checkpoint"
|
95 |
+
|
96 |
+
config = HubertForBarkSemanticConfig(**model_config)
|
97 |
+
model = HuBERTForBarkSemantic(
|
98 |
+
config=config, load_hubert_pretrained_weights=False, device=device
|
99 |
+
)
|
100 |
+
model.load_state_dict(state_dict=state_dict, strict=True)
|
101 |
+
|
102 |
+
return Model(model=model, config=config, preprocessor=None)
|
103 |
+
|
104 |
+
|
105 |
+
# TODO: refactor this class, each ModelInfo should have its own _load_model function for consistency
|
106 |
+
# and avoid complicated if-else paths
|
107 |
+
class ModelEnum(Enum):
|
108 |
+
"""
|
109 |
+
Enumeration of supported models with their metadata.
|
110 |
+
Each entry maps to a ModelInfo object defining how to load the model.
|
111 |
+
"""
|
112 |
+
|
113 |
+
BARK_TEXT_SMALL = ModelInfo(
|
114 |
+
repo_id="suno/bark",
|
115 |
+
file_name="text.pt",
|
116 |
+
model_type="text",
|
117 |
+
model_class=GPT,
|
118 |
+
preprocessor_class=BertTokenizer,
|
119 |
+
)
|
120 |
+
BARK_COARSE_SMALL = ModelInfo(
|
121 |
+
repo_id="suno/bark", file_name="coarse.pt", model_type="coarse"
|
122 |
+
)
|
123 |
+
BARK_FINE_SMALL = ModelInfo(
|
124 |
+
repo_id="suno/bark", file_name="fine.pt", model_type="fine"
|
125 |
+
)
|
126 |
+
|
127 |
+
BARK_TEXT = ModelInfo(repo_id="suno/bark", file_name="text_2.pt", model_type="text")
|
128 |
+
BARK_COARSE = ModelInfo(
|
129 |
+
repo_id="suno/bark", file_name="coarse_2.pt", model_type="coarse"
|
130 |
+
)
|
131 |
+
BARK_FINE = ModelInfo(repo_id="suno/bark", file_name="fine_2.pt", model_type="fine")
|
132 |
+
|
133 |
+
CustomHuBERTTokenizer = ModelInfo(
|
134 |
+
repo_id="GitMylo/bark-voice-cloning",
|
135 |
+
file_name="quantifier_hubert_base_ls960_14.pth",
|
136 |
+
model_type="custom_hubert_tokenizer",
|
137 |
+
)
|
138 |
+
|
139 |
+
ENCODEC24k = ModelInfo(
|
140 |
+
checkpoint_name="facebook/encodec_24khz",
|
141 |
+
model_type="encodec",
|
142 |
+
load_model=_load_encodec_model,
|
143 |
+
)
|
144 |
+
|
145 |
+
HuBERTBaseForBarkSemantic = ModelInfo(
|
146 |
+
checkpoint_name="facebook/hubert-base-ls960",
|
147 |
+
repo_id="sleeper371/hubert-for-bark-semantic",
|
148 |
+
file_name="hubert_epoch_30_2025_04_06_03_23_eval_loss_0.5520355800787607_acc_0.8344086021505376.pt",
|
149 |
+
load_model=_load_hubert_base_for_bark_semantic,
|
150 |
+
)
|
151 |
+
|
152 |
+
@classmethod
|
153 |
+
def get_model_info(cls, model_name: str) -> ModelInfo:
|
154 |
+
"""
|
155 |
+
Retrieve ModelInfo for a given model name.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
model_name (str): Name of the model (e.g., "BARK_TEXT_SMALL")
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
ModelInfo: Metadata for the requested model
|
162 |
+
|
163 |
+
Raises:
|
164 |
+
ValueError: If the model name is not recognized
|
165 |
+
"""
|
166 |
+
try:
|
167 |
+
return cls[model_name].value
|
168 |
+
except KeyError:
|
169 |
+
raise ValueError(f"Unknown model name: {model_name}")
|
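A small sketch of the enum lookup (names as defined above; illustrative only):

from core.memory.models import ModelEnum

info = ModelEnum.get_model_info("BARK_COARSE")
assert info.repo_id == "suno/bark" and info.file_name == "coarse_2.pt"

# Unknown names raise ValueError rather than KeyError.
try:
    ModelEnum.get_model_info("NOT_A_MODEL")
except ValueError as err:
    print(err)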
core/model/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
from core.model.bark import *
|
core/model/bark.py
ADDED
@@ -0,0 +1,425 @@
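Since forward() returns logits only for the last position plus an optional KV cache, generation is driven by an external loop. A minimal sketch of that loop with small, assumed hyperparameters and untrained weights (greedy decoding on random tokens, no Bark-specific pre/post-processing):

import torch
from core.model.bark import GPT, GPTConfig

config = GPTConfig(n_layer=2, n_head=2, n_embd=64, block_size=128)
model = GPT(config).eval()

idx = torch.randint(0, config.input_vocab_size, (1, 8))  # prompt tokens
past_kv = None
with torch.no_grad():
    for _ in range(16):
        # After the first step, feed only the newest token together with the cached keys/values.
        inp = idx if past_kv is None else idx[:, -1:]
        logits, past_kv = model(inp, past_kv=past_kv, use_cache=True)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        idx = torch.cat([idx, next_token], dim=1)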
1 |
+
"""
|
2 |
+
codes adapted from https://github.com/suno-ai/bark
|
3 |
+
"""
|
4 |
+
|
5 |
+
import math
|
6 |
+
from dataclasses import dataclass
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch.nn import functional as F
|
11 |
+
|
12 |
+
|
13 |
+
@dataclass
|
14 |
+
class GPTConfig:
|
15 |
+
block_size: int = 1024
|
16 |
+
input_vocab_size: int = 10_048
|
17 |
+
output_vocab_size: int = 10_048
|
18 |
+
n_layer: int = 12
|
19 |
+
n_head: int = 12
|
20 |
+
n_embd: int = 768
|
21 |
+
dropout: float = 0.0
|
22 |
+
bias: bool = (
|
23 |
+
True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class FineGPTConfig(GPTConfig):
|
29 |
+
n_codes_total: int = 8
|
30 |
+
n_codes_given: int = 1
|
31 |
+
|
32 |
+
|
33 |
+
class LayerNorm(nn.Module):
|
34 |
+
"""LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
|
35 |
+
|
36 |
+
def __init__(self, ndim: int, bias: bool) -> None:
|
37 |
+
super().__init__()
|
38 |
+
self.weight = nn.Parameter(torch.ones(ndim))
|
39 |
+
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
|
40 |
+
|
41 |
+
def forward(self, input):
|
42 |
+
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
|
43 |
+
|
44 |
+
|
45 |
+
class MLP(nn.Module):
|
46 |
+
|
47 |
+
def __init__(self, config: GPTConfig):
|
48 |
+
super().__init__()
|
49 |
+
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
|
50 |
+
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
|
51 |
+
self.dropout = nn.Dropout(config.dropout)
|
52 |
+
self.gelu = nn.GELU()
|
53 |
+
|
54 |
+
def forward(self, x) -> torch.Tensor:
|
55 |
+
x = self.c_fc(x)
|
56 |
+
x = self.gelu(x)
|
57 |
+
x = self.c_proj(x)
|
58 |
+
x = self.dropout(x)
|
59 |
+
return x
|
60 |
+
|
61 |
+
|
62 |
+
class CausalSelfAttention(nn.Module):
|
63 |
+
def __init__(self, config: GPTConfig) -> None:
|
64 |
+
super().__init__()
|
65 |
+
assert config.n_embd % config.n_head == 0
|
66 |
+
|
67 |
+
# key, query, value projections for all heads, but in a batch
|
68 |
+
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
|
69 |
+
# output projection
|
70 |
+
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
|
71 |
+
# regularization
|
72 |
+
self.attn_dropout = nn.Dropout(config.dropout)
|
73 |
+
self.resid_dropout = nn.Dropout(config.dropout)
|
74 |
+
self.n_head = config.n_head
|
75 |
+
self.n_embd = config.n_embd
|
76 |
+
self.dropout = config.dropout
|
77 |
+
# flash attention makes the GPU go brrrrr, but it requires PyTorch >= 2.0
|
78 |
+
self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
|
79 |
+
if not self.flash:
|
80 |
+
# print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
|
81 |
+
# causal mask to ensure that attention is only applied to the left in the input sequence
|
82 |
+
self.register_buffer(
|
83 |
+
"bias",
|
84 |
+
torch.tril(torch.ones(config.block_size, config.block_size)).view(
|
85 |
+
1, 1, config.block_size, config.block_size
|
86 |
+
),
|
87 |
+
)
|
88 |
+
|
89 |
+
def forward(
|
90 |
+
self, x: torch.Tensor, past_kv: torch.Tensor = None, use_cache: bool = False
|
91 |
+
):
|
92 |
+
B, T, C = (
|
93 |
+
x.size()
|
94 |
+
) # batch size, sequence length, embedding dimensionality (n_embd)
|
95 |
+
|
96 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
97 |
+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
98 |
+
k = k.view(B, T, self.n_head, C // self.n_head).transpose(
|
99 |
+
1, 2
|
100 |
+
) # (B, nh, T, hs)
|
101 |
+
q = q.view(B, T, self.n_head, C // self.n_head).transpose(
|
102 |
+
1, 2
|
103 |
+
) # (B, nh, T, hs)
|
104 |
+
v = v.view(B, T, self.n_head, C // self.n_head).transpose(
|
105 |
+
1, 2
|
106 |
+
) # (B, nh, T, hs)
|
107 |
+
|
108 |
+
if past_kv is not None:
|
109 |
+
past_key = past_kv[0]
|
110 |
+
past_value = past_kv[1]
|
111 |
+
k = torch.cat((past_key, k), dim=-2)
|
112 |
+
v = torch.cat((past_value, v), dim=-2)
|
113 |
+
|
114 |
+
FULL_T = k.shape[-2]
|
115 |
+
|
116 |
+
if use_cache is True:
|
117 |
+
present = (k, v)
|
118 |
+
else:
|
119 |
+
present = None
|
120 |
+
|
121 |
+
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
122 |
+
if self.flash:
|
123 |
+
# efficient attention using Flash Attention CUDA kernels
|
124 |
+
if past_kv is not None:
|
125 |
+
# When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains
|
126 |
+
# the query for the last token. scaled_dot_product_attention interprets this as the first token in the
|
127 |
+
# sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so
|
128 |
+
# to work around this we set is_causal=False.
|
129 |
+
is_causal = False
|
130 |
+
else:
|
131 |
+
is_causal = True
|
132 |
+
|
133 |
+
y = torch.nn.functional.scaled_dot_product_attention(
|
134 |
+
q, k, v, dropout_p=self.dropout, is_causal=is_causal
|
135 |
+
)
|
136 |
+
else:
|
137 |
+
# manual implementation of attention
|
138 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
139 |
+
att = att.masked_fill(
|
140 |
+
self.bias[:, :, FULL_T - T : FULL_T, :FULL_T] == 0, float("-inf")
|
141 |
+
)
|
142 |
+
att = F.softmax(att, dim=-1)
|
143 |
+
att = self.attn_dropout(att)
|
144 |
+
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
145 |
+
y = (
|
146 |
+
y.transpose(1, 2).contiguous().view(B, T, C)
|
147 |
+
) # re-assemble all head outputs side by side
|
148 |
+
|
149 |
+
# output projection
|
150 |
+
y = self.resid_dropout(self.c_proj(y))
|
151 |
+
return (y, present)
|
152 |
+
|
153 |
+
|
154 |
+
class Block(nn.Module):
|
155 |
+
|
156 |
+
def __init__(self, config: GPTConfig, layer_idx: int) -> None:
|
157 |
+
super().__init__()
|
158 |
+
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
|
159 |
+
self.attn = CausalSelfAttention(config)
|
160 |
+
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
|
161 |
+
self.mlp = MLP(config)
|
162 |
+
self.layer_idx = layer_idx
|
163 |
+
|
164 |
+
def forward(
|
165 |
+
self, x: torch.Tensor, past_kv: torch.Tensor = None, use_cache: bool = False
|
166 |
+
):
|
167 |
+
attn_output, prev_kvs = self.attn(
|
168 |
+
self.ln_1(x), past_kv=past_kv, use_cache=use_cache
|
169 |
+
)
|
170 |
+
x = x + attn_output
|
171 |
+
x = x + self.mlp(self.ln_2(x))
|
172 |
+
return (x, prev_kvs)
|
173 |
+
|
174 |
+
|
175 |
+
class GPT(nn.Module):
|
176 |
+
def __init__(self, config: GPTConfig):
|
177 |
+
super().__init__()
|
178 |
+
assert config.input_vocab_size is not None
|
179 |
+
assert config.output_vocab_size is not None
|
180 |
+
assert config.block_size is not None
|
181 |
+
self.config = config
|
182 |
+
|
183 |
+
self.transformer = nn.ModuleDict(
|
184 |
+
dict(
|
185 |
+
wte=nn.Embedding(config.input_vocab_size, config.n_embd),
|
186 |
+
wpe=nn.Embedding(config.block_size, config.n_embd),
|
187 |
+
drop=nn.Dropout(config.dropout),
|
188 |
+
h=nn.ModuleList([Block(config, idx) for idx in range(config.n_layer)]),
|
189 |
+
ln_f=LayerNorm(config.n_embd, bias=config.bias),
|
190 |
+
)
|
191 |
+
)
|
192 |
+
self.lm_head = nn.Linear(config.n_embd, config.output_vocab_size, bias=False)
|
193 |
+
# Note: lm_head lacks bias, implying parameter sharing with wte for efficiency
|
194 |
+
|
195 |
+
def get_num_params(self, non_embedding: bool = True) -> int:
|
196 |
+
"""
|
197 |
+
Return the number of parameters in the model.
|
198 |
+
For non-embedding count (default), the position embeddings get subtracted.
|
199 |
+
The token embeddings would too, except due to the parameter sharing these
|
200 |
+
params are actually used as weights in the final layer, so we include them.
|
201 |
+
"""
|
202 |
+
n_params = sum(p.numel() for p in self.parameters())
|
203 |
+
if non_embedding:
|
204 |
+
n_params -= self.transformer.wte.weight.numel()
|
205 |
+
n_params -= self.transformer.wpe.weight.numel()
|
206 |
+
return n_params
|
207 |
+
|
208 |
+
def forward(
|
209 |
+
self,
|
210 |
+
idx: torch.Tensor,
|
211 |
+
merge_context: bool = False,
|
212 |
+
past_kv: torch.Tensor = None,
|
213 |
+
position_ids: torch.Tensor = None,
|
214 |
+
use_cache: bool = False,
|
215 |
+
):
|
216 |
+
device = idx.device
|
217 |
+
b, t = idx.size()
|
218 |
+
if past_kv is not None:
|
219 |
+
# When past_kv is provided, this is optimized for autoregressive generation
|
220 |
+
assert (
|
221 |
+
t == 1
|
222 |
+
), "should only pass in the last token of the sequence when using kv_cache"
|
223 |
+
# Shape: (b, 1, n_embd), single token case
|
224 |
+
tok_emb = self.transformer.wte(idx)
|
225 |
+
else:
|
226 |
+
if merge_context:
|
227 |
+
# Custom feature: assumes first 256 tokens are one context, next 256 another, rest is sequence
|
228 |
+
assert idx.shape[1] >= 256 + 256 + 1
|
229 |
+
t = idx.shape[1] - 256 # Adjusts t for merged context length
|
230 |
+
else:
|
231 |
+
assert (
|
232 |
+
t <= self.config.block_size
|
233 |
+
), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
|
234 |
+
|
235 |
+
if merge_context:
|
236 |
+
# Merges two contexts by adding their embeddings, not a standard GPT behavior
|
237 |
+
tok_emb = torch.cat(
|
238 |
+
[
|
239 |
+
self.transformer.wte(idx[:, :256])
|
240 |
+
+ self.transformer.wte(idx[:, 256 : 256 + 256]),
|
241 |
+
self.transformer.wte(idx[:, 256 + 256 :]),
|
242 |
+
],
|
243 |
+
dim=1,
|
244 |
+
)
|
245 |
+
else:
|
246 |
+
tok_emb = self.transformer.wte(idx)
|
247 |
+
|
248 |
+
if past_kv is None:
|
249 |
+
past_length = 0
|
250 |
+
# Empty cache for each layer
|
251 |
+
past_kv = tuple([None] * len(self.transformer.h))
|
252 |
+
else:
|
253 |
+
# Infers prior sequence length from cache
|
254 |
+
past_length = past_kv[0][0].size(-2)
|
255 |
+
|
256 |
+
if position_ids is None:
|
257 |
+
position_ids = torch.arange(
|
258 |
+
past_length, t + past_length, dtype=torch.long, device=device
|
259 |
+
)
|
260 |
+
position_ids = position_ids.unsqueeze(0)
|
261 |
+
assert position_ids.shape == (1, t)
|
262 |
+
|
263 |
+
pos_emb = self.transformer.wpe(position_ids)
|
264 |
+
|
265 |
+
x = self.transformer.drop(tok_emb + pos_emb)
|
266 |
+
|
267 |
+
# Prepares cache for key-value pairs if enabled
|
268 |
+
new_kv = () if use_cache else None
|
269 |
+
|
270 |
+
for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)):
|
271 |
+
x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache)
|
272 |
+
if use_cache:
|
273 |
+
new_kv = new_kv + (kv,) # Accumulates new key-value pairs for caching
|
274 |
+
|
275 |
+
x = self.transformer.ln_f(x)
|
276 |
+
|
277 |
+
# Optimization: only computes logits for the last token, efficient for generation
|
278 |
+
logits = self.lm_head(x[:, [-1], :]) # Preserves time dim with [-1]
|
279 |
+
|
280 |
+
return (
|
281 |
+
logits,
|
282 |
+
new_kv,
|
283 |
+
) # Returns tuple: logits for next token, cache if requested
|
284 |
+
|
285 |
+
|
286 |
+
class NonCausalSelfAttention(nn.Module):
|
287 |
+
def __init__(self, config):
|
288 |
+
super().__init__()
|
289 |
+
assert config.n_embd % config.n_head == 0
|
290 |
+
# key, query, value projections for all heads, but in a batch
|
291 |
+
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
|
292 |
+
# output projection
|
293 |
+
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
|
294 |
+
# regularization
|
295 |
+
self.attn_dropout = nn.Dropout(config.dropout)
|
296 |
+
self.resid_dropout = nn.Dropout(config.dropout)
|
297 |
+
self.n_head = config.n_head
|
298 |
+
self.n_embd = config.n_embd
|
299 |
+
self.dropout = config.dropout
|
300 |
+
# flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
|
301 |
+
self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
|
302 |
+
|
303 |
+
def forward(self, x):
|
304 |
+
B, T, C = (
|
305 |
+
x.size()
|
306 |
+
) # batch size, sequence length, embedding dimensionality (n_embd)
|
307 |
+
|
308 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
309 |
+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
310 |
+
k = k.view(B, T, self.n_head, C // self.n_head).transpose(
|
311 |
+
1, 2
|
312 |
+
) # (B, nh, T, hs)
|
313 |
+
q = q.view(B, T, self.n_head, C // self.n_head).transpose(
|
314 |
+
1, 2
|
315 |
+
) # (B, nh, T, hs)
|
316 |
+
v = v.view(B, T, self.n_head, C // self.n_head).transpose(
|
317 |
+
1, 2
|
318 |
+
) # (B, nh, T, hs)
|
319 |
+
|
320 |
+
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
321 |
+
if self.flash:
|
322 |
+
# efficient attention using Flash Attention CUDA kernels
|
323 |
+
y = torch.nn.functional.scaled_dot_product_attention(
|
324 |
+
q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=False
|
325 |
+
)
|
326 |
+
else:
|
327 |
+
# manual implementation of attention
|
328 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
329 |
+
att = F.softmax(att, dim=-1)
|
330 |
+
att = self.attn_dropout(att)
|
331 |
+
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
332 |
+
y = (
|
333 |
+
y.transpose(1, 2).contiguous().view(B, T, C)
|
334 |
+
) # re-assemble all head outputs side by side
|
335 |
+
|
336 |
+
# output projection
|
337 |
+
y = self.resid_dropout(self.c_proj(y))
|
338 |
+
return y
|
339 |
+
|
340 |
+
|
341 |
+
class FineBlock(nn.Module):
|
342 |
+
def __init__(self, config):
|
343 |
+
super().__init__()
|
344 |
+
self.ln_1 = nn.LayerNorm(config.n_embd)
|
345 |
+
self.attn = NonCausalSelfAttention(config)
|
346 |
+
self.ln_2 = nn.LayerNorm(config.n_embd)
|
347 |
+
self.mlp = MLP(config)
|
348 |
+
|
349 |
+
def forward(self, x):
|
350 |
+
x = x + self.attn(self.ln_1(x))
|
351 |
+
x = x + self.mlp(self.ln_2(x))
|
352 |
+
return x
|
353 |
+
|
354 |
+
|
355 |
+
class FineGPT(GPT):
|
356 |
+
def __init__(self, config):
|
357 |
+
super().__init__(config)
|
358 |
+
del self.lm_head
|
359 |
+
self.config = config
|
360 |
+
self.n_codes_total = config.n_codes_total
|
361 |
+
self.transformer = nn.ModuleDict(
|
362 |
+
dict(
|
363 |
+
wtes=nn.ModuleList(
|
364 |
+
[
|
365 |
+
nn.Embedding(config.input_vocab_size, config.n_embd)
|
366 |
+
for _ in range(config.n_codes_total)
|
367 |
+
]
|
368 |
+
),
|
369 |
+
wpe=nn.Embedding(config.block_size, config.n_embd),
|
370 |
+
drop=nn.Dropout(config.dropout),
|
371 |
+
h=nn.ModuleList([FineBlock(config) for _ in range(config.n_layer)]),
|
372 |
+
ln_f=nn.LayerNorm(config.n_embd),
|
373 |
+
)
|
374 |
+
)
|
375 |
+
self.lm_heads = nn.ModuleList(
|
376 |
+
[
|
377 |
+
nn.Linear(config.n_embd, config.output_vocab_size, bias=False)
|
378 |
+
for _ in range(config.n_codes_given, self.n_codes_total)
|
379 |
+
]
|
380 |
+
)
|
381 |
+
for i in range(self.n_codes_total - config.n_codes_given):
|
382 |
+
self.transformer.wtes[i + 1].weight = self.lm_heads[i].weight
|
383 |
+
|
384 |
+
def forward(self, pred_idx, idx):
|
385 |
+
device = idx.device
|
386 |
+
b, t, codes = idx.size()
|
387 |
+
assert (
|
388 |
+
t <= self.config.block_size
|
389 |
+
), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
|
390 |
+
assert pred_idx > 0, "cannot predict 0th codebook"
|
391 |
+
assert codes == self.n_codes_total, (b, t, codes)
|
392 |
+
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
|
393 |
+
0
|
394 |
+
) # shape (1, t)
|
395 |
+
|
396 |
+
# forward the GPT model itself
|
397 |
+
tok_embs = [
|
398 |
+
wte(idx[:, :, i]).unsqueeze(-1)
|
399 |
+
for i, wte in enumerate(self.transformer.wtes)
|
400 |
+
] # per-codebook token embeddings, each of shape (b, t, n_embd, 1)
|
401 |
+
tok_emb = torch.cat(tok_embs, dim=-1)
|
402 |
+
pos_emb = self.transformer.wpe(
|
403 |
+
pos
|
404 |
+
) # position embeddings of shape (1, t, n_embd)
|
405 |
+
x = tok_emb[:, :, :, : pred_idx + 1].sum(dim=-1)
|
406 |
+
x = self.transformer.drop(x + pos_emb)
|
407 |
+
for block in self.transformer.h:
|
408 |
+
x = block(x)
|
409 |
+
x = self.transformer.ln_f(x)
|
410 |
+
logits = self.lm_heads[pred_idx - self.config.n_codes_given](x)
|
411 |
+
return logits
|
412 |
+
|
413 |
+
def get_num_params(self, non_embedding=True):
|
414 |
+
"""
|
415 |
+
Return the number of parameters in the model.
|
416 |
+
For non-embedding count (default), the position embeddings get subtracted.
|
417 |
+
The token embeddings would too, except due to the parameter sharing these
|
418 |
+
params are actually used as weights in the final layer, so we include them.
|
419 |
+
"""
|
420 |
+
n_params = sum(p.numel() for p in self.parameters())
|
421 |
+
if non_embedding:
|
422 |
+
for wte in self.transformer.wtes:
|
423 |
+
n_params -= wte.weight.numel()
|
424 |
+
n_params -= self.transformer.wpe.weight.numel()
|
425 |
+
return n_params
|
core/model/hubert.py
ADDED
@@ -0,0 +1,237 @@
1 |
+
from dataclasses import dataclass
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Optional, Tuple, Union, Literal
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from transformers.modeling_outputs import BaseModelOutput
|
8 |
+
from transformers import HubertModel, AutoConfig, AutoModel
|
9 |
+
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class CustomHubertConfig:
|
13 |
+
"""Configuration class for CustomHubert model."""
|
14 |
+
|
15 |
+
# e.g., "facebook/hubert-base-ls960" or "facebook/hubert-large-ll60k"
|
16 |
+
checkpoint_name: str
|
17 |
+
# Layer to extract features from (0-indexed, e.g., 9 for 10th layer)
|
18 |
+
feature_layer: int = 11
|
19 |
+
# Target audio sample rate in Hz
|
20 |
+
target_sample_rate: int = 16000
|
21 |
+
# Optional length multiple for audio trimming
|
22 |
+
seq_len_multiple_of: Optional[int] = None
|
23 |
+
|
24 |
+
|
25 |
+
@dataclass
|
26 |
+
class HubertForBarkSemanticConfig:
|
27 |
+
"""Configuration for HuBERTForBarkSemantic."""
|
28 |
+
|
29 |
+
# HuBERT checkpoint used for the feature extractor
|
30 |
+
checkpoint_name: Literal["facebook/hubert-base-ls960", "hubert-large-ls960-ft"]
|
31 |
+
vocab_size: int
|
32 |
+
# Layer to extract features from
|
33 |
+
feature_layer: int = 11
|
34 |
+
# the last three vocabulary ids are reserved for the SOS, EOS, and PAD tokens
|
35 |
+
# maximum target sequence length
|
36 |
+
max_target_length: int = 2000
|
37 |
+
num_decoder_layer: int = 12
|
38 |
+
sos_token_id: int = 10000
|
39 |
+
eos_token_id: int = 10001
|
40 |
+
|
41 |
+
|
42 |
+
class HubertFeatureExtractor(nn.Module):
|
43 |
+
"""
|
44 |
+
A custom HuBERT model that loads a pretrained model from transformers and extracts
|
45 |
+
features from a specified layer. Processes raw audio waveforms and returns hidden states.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
config (CustomHubertConfig): Configuration specifying checkpoint, layer, and audio settings.
|
49 |
+
device (torch.device, optional): Device to run the model on (e.g., "cuda" or "cpu").
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
config: CustomHubertConfig,
|
55 |
+
load_pretrained_weights: bool,
|
56 |
+
device: Optional[torch.device] = None,
|
57 |
+
):
|
58 |
+
super().__init__()
|
59 |
+
self.config = config
|
60 |
+
self.target_sample_rate = config.target_sample_rate
|
61 |
+
|
62 |
+
# Load pretrained HuBERT model from transformers
|
63 |
+
self.hubert_config = AutoConfig.from_pretrained(config.checkpoint_name)
|
64 |
+
if load_pretrained_weights:
|
65 |
+
self.model = HubertModel.from_pretrained(config.checkpoint_name)
|
66 |
+
else:
|
67 |
+
# don't download the pretrained weights, init the model from the config
|
68 |
+
self.model = AutoModel.from_config(self.hubert_config)
|
69 |
+
|
70 |
+
# Validate feature_layer
|
71 |
+
# e.g., 12 for BASE, 24 for LARGE
|
72 |
+
num_layers = self.model.config.num_hidden_layers
|
73 |
+
if not (0 <= config.feature_layer < num_layers):
|
74 |
+
raise ValueError(
|
75 |
+
f"feature_layer must be between 0 and {num_layers - 1}, got {config.feature_layer}"
|
76 |
+
)
|
77 |
+
self.feature_layer = config.feature_layer
|
78 |
+
|
79 |
+
# Move to device if specified
|
80 |
+
if device is not None:
|
81 |
+
self.to(device)
|
82 |
+
|
83 |
+
@property
|
84 |
+
def hidden_size(self) -> int:
|
85 |
+
"""Returns the hidden size of the HuBERT model (e.g., 768 for BASE, 1024 for LARGE)."""
|
86 |
+
return self.model.config.hidden_size
|
87 |
+
|
88 |
+
def forward(
|
89 |
+
self,
|
90 |
+
wav_input: torch.Tensor,
|
91 |
+
) -> torch.Tensor:
|
92 |
+
"""
|
93 |
+
Processes raw audio waveforms through HuBERT and extracts features from the specified layer.
|
94 |
+
The input audio is expected to be sampled at 16 kHz.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
wav_input (torch.Tensor): Raw audio waveforms, shape [batch_size, audio_length].
|
98 |
+
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
torch.Tensor: Hidden states from the specified layer,
|
102 |
+
with shape [batch_size, seq_length, hidden_size];
|
103 |
+
the batch and time dimensions are preserved (no flattening is applied).
|
104 |
+
"""
|
105 |
+
|
106 |
+
# Forward pass through HuBERT
|
107 |
+
# output_hidden_states=True returns all layer outputs
|
108 |
+
outputs: BaseModelOutput = self.model(
|
109 |
+
input_values=wav_input, output_hidden_states=True, return_dict=True
|
110 |
+
)
|
111 |
+
|
112 |
+
# Extract features from the specified layer (0-indexed)
|
113 |
+
# hidden_states is a tuple of [batch_size, seq_length, hidden_size] for each layer
|
114 |
+
features = outputs.hidden_states[self.feature_layer] # e.g., [2, 500, 768]
|
115 |
+
features = features.contiguous()
|
116 |
+
return features
|
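As a shape sanity check: HuBERT BASE downsamples 16 kHz audio by a factor of 320 (roughly 50 frames per second), so one second of audio yields about 49 frames of 768-dimensional features. A sketch with randomly initialized weights (illustrative only):

import torch
from core.model.hubert import HubertFeatureExtractor, CustomHubertConfig

config = CustomHubertConfig(checkpoint_name="facebook/hubert-base-ls960", feature_layer=11)
extractor = HubertFeatureExtractor(config, load_pretrained_weights=False)

wav = torch.randn(2, 16_000)   # one second of 16 kHz audio per batch item
features = extractor(wav)
print(features.shape)          # approximately [2, 49, 768]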
117 |
+
|
118 |
+
|
119 |
+
class HuBERTForBarkSemantic(nn.Module):
|
120 |
+
def __init__(
|
121 |
+
self,
|
122 |
+
config: HubertForBarkSemanticConfig,
|
123 |
+
load_hubert_pretrained_weights: bool = True,
|
124 |
+
device: Optional[torch.device] = None,
|
125 |
+
):
|
126 |
+
super().__init__()
|
127 |
+
self.config = config
|
128 |
+
|
129 |
+
# HuBERT feature extractor
|
130 |
+
hubert_config = CustomHubertConfig(
|
131 |
+
checkpoint_name=config.checkpoint_name,
|
132 |
+
feature_layer=config.feature_layer,
|
133 |
+
)
|
134 |
+
self.hubert = HubertFeatureExtractor(
|
135 |
+
config=hubert_config,
|
136 |
+
load_pretrained_weights=load_hubert_pretrained_weights,
|
137 |
+
device=device,
|
138 |
+
)
|
139 |
+
|
140 |
+
# e.g., 768 for BASE
|
141 |
+
input_size = self.hubert.model.config.hidden_size
|
142 |
+
|
143 |
+
# Transformer Decoder
|
144 |
+
self.decoder_embedding = nn.Embedding(config.vocab_size, input_size)
|
145 |
+
self.pos_embedding = nn.Parameter(
|
146 |
+
torch.zeros(1, config.max_target_length, input_size)
|
147 |
+
)
|
148 |
+
self.decoder = nn.TransformerDecoder(
|
149 |
+
nn.TransformerDecoderLayer(
|
150 |
+
d_model=input_size,
|
151 |
+
nhead=8,
|
152 |
+
dim_feedforward=2048,
|
153 |
+
dropout=0.1,
|
154 |
+
batch_first=True,
|
155 |
+
),
|
156 |
+
num_layers=config.num_decoder_layer, # Adjust as needed
|
157 |
+
)
|
158 |
+
self.fc = nn.Linear(input_size, config.vocab_size)
|
159 |
+
|
160 |
+
if device is not None:
|
161 |
+
self.to(device)
|
162 |
+
|
163 |
+
def save_state_dict(self, save_path: str):
|
164 |
+
torch.save(self.state_dict(), save_path)
|
165 |
+
|
166 |
+
def forward(self, wav_input: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
|
167 |
+
"""
|
168 |
+
Forward pass: Extracts HuBERT features and predicts semantic token probabilities.
|
169 |
+
|
170 |
+
Args:
|
171 |
+
wav_input: [batch_size, audio_length] (e.g., [2, 160000])
|
172 |
+
tgt: the target sequence
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
[batch_size, target_length, vocab_size] logits over the semantic vocabulary
|
176 |
+
"""
|
177 |
+
memory: torch.Tensor = self.hubert(wav_input) # [B, T, 768]
|
178 |
+
B, T_tgt = tgt.shape
|
179 |
+
tgt_emb = self.decoder_embedding(tgt) + self.pos_embedding[:, :T_tgt, :]
|
180 |
+
tgt_mask = nn.Transformer.generate_square_subsequent_mask(T_tgt).to(tgt.device)
|
181 |
+
|
182 |
+
output: torch.Tensor = self.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
|
183 |
+
logits = self.fc(output)
|
184 |
+
return logits
|
185 |
+
|
186 |
+
@torch.no_grad
|
187 |
+
def generate(
|
188 |
+
self,
|
189 |
+
wav_input: torch.Tensor,
|
190 |
+
temperature: Optional[float] = 0.8,
|
191 |
+
eos_p: Optional[float] = 0.5,
|
192 |
+
max_length: int = 600,
|
193 |
+
) -> torch.Tensor:
|
194 |
+
"""
|
195 |
+
Inference: autoregressive generation.
|
196 |
+
Assumes wav_input is sampled at 16 kHz."""
|
197 |
+
self.eval()
|
198 |
+
memory = self.hubert(wav_input)
|
199 |
+
B = wav_input.shape[0]
|
200 |
+
tgt = torch.full(
|
201 |
+
size=(B, 1), fill_value=self.config.sos_token_id, device=wav_input.device
|
202 |
+
)
|
203 |
+
|
204 |
+
for _ in range(max_length):
|
205 |
+
tgt_emb = (
|
206 |
+
self.decoder_embedding(tgt) + self.pos_embedding[:, : tgt.shape[1], :]
|
207 |
+
)
|
208 |
+
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.shape[1]).to(
|
209 |
+
tgt.device
|
210 |
+
)
|
211 |
+
|
212 |
+
output = self.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
|
213 |
+
# logits shape (B, T', vocab_size)
|
214 |
+
logits: torch.Tensor = self.fc(output[:, -1, :])
|
215 |
+
|
216 |
+
if temperature is not None and temperature > 0:
|
217 |
+
probs = torch.softmax(input=logits / temperature, dim=-1)
|
218 |
+
next_token = torch.multinomial(input=probs, num_samples=1)
|
219 |
+
else:
|
220 |
+
probs = torch.softmax(input=logits, dim=-1)
|
221 |
+
next_token = logits.argmax(dim=-1, keepdim=True)
|
222 |
+
|
223 |
+
# stop if the EOS token probabilities are higher than the provided eos_p
|
224 |
+
if eos_p is not None and eos_p > 0:
|
225 |
+
if torch.all(probs[:, self.config.eos_token_id] > eos_p):
|
226 |
+
break
|
227 |
+
|
228 |
+
# early stopping
|
229 |
+
if torch.all(next_token == self.config.eos_token_id):
|
230 |
+
break
|
231 |
+
|
232 |
+
tgt = torch.cat([tgt, next_token], dim=1)
|
233 |
+
if (next_token == self.config.eos_token_id).all():
|
234 |
+
break
|
235 |
+
|
236 |
+
# remove the [SOS] token from the generated semantic sequences
|
237 |
+
return tgt[:, 1:]
|
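A rough inference sketch for this model on a 16 kHz waveform (the vocab_size split below is an assumption consistent with the SOS/EOS/PAD ids used in the trainer; checkpoint loading from core/memory/models.py is omitted, so the weights here are untrained):

import torch
from core.model.hubert import HuBERTForBarkSemantic, HubertForBarkSemanticConfig

config = HubertForBarkSemanticConfig(
    checkpoint_name="facebook/hubert-base-ls960",
    vocab_size=10_003,  # assumed: 10_000 semantic tokens + SOS/EOS/PAD
)
model = HuBERTForBarkSemantic(config, load_hubert_pretrained_weights=False)

wav = torch.randn(1, 2 * 16_000)  # two seconds of 16 kHz audio
semantic_tokens = model.generate(wav, temperature=0.8, eos_p=0.5, max_length=50)
print(semantic_tokens.shape)      # [1, <= 50]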
core/trainer/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
from core.trainer.custom_hubert_trainer import *
|
core/trainer/custom_hubert_trainer.py
ADDED
@@ -0,0 +1,555 @@
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from datetime import datetime
|
4 |
+
import logging
|
5 |
+
import sys
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch.optim import Adam
|
10 |
+
from torch.optim.lr_scheduler import LRScheduler, LinearLR
|
11 |
+
from torch.utils.data import Dataset, DataLoader, random_split
|
12 |
+
import torchaudio
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
from typing import Literal, List, Optional, Tuple, Dict, Callable, Union, Any
|
16 |
+
from core.data_model import WavSemantic, WavSemanticDataset
|
17 |
+
from core.utils import read_audio_file, upload_file_to_hf
|
18 |
+
|
19 |
+
# cuDNN raised an error about non-contiguous input at the LSTM layer; disabling it fixes the issue
|
20 |
+
torch.backends.cudnn.enabled = False
|
21 |
+
|
22 |
+
# Set up logging
|
23 |
+
logging.basicConfig(
|
24 |
+
level=logging.INFO,
|
25 |
+
format="%(asctime)s - %(levelname)s - %(message)s",
|
26 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
27 |
+
)
|
28 |
+
logger = logging.getLogger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
HUBERT_SAMPLE_RATE = 16000
|
32 |
+
# 10_000 and 10_001 are for SOS and EOS tokens
|
33 |
+
SEMANTIC_PADDING_TOKEN = 10002
|
34 |
+
SOS_TOKEN = 10_000
|
35 |
+
EOS_TOKEN = 10_001
|
36 |
+
|
37 |
+
|
38 |
+
class WavSemanticTorchDataset(Dataset):
|
39 |
+
"""PyTorch Dataset for WavSemantic data with resampling and noise augmentation.
|
40 |
+
Padding is carried out in a collator function.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
samples: List of WavSemantic objects (speech data).
|
44 |
+
orig_sample_rate: Original sample rate of the audio.
|
45 |
+
target_sample_rate: Desired sample rate (default: 16000 Hz).
|
46 |
+
device: Device to move tensors to (optional).
|
47 |
+
noises: List of noise waveforms as NumPy arrays (optional, for augmentation).
|
48 |
+
Noise audio must already be at target_sample_rate; this class does not resample it.
|
49 |
+
augment_prob: Probability of applying noise augmentation (default: 0.5).
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
samples: List["WavSemantic"],
|
55 |
+
orig_sample_rate: int,
|
56 |
+
target_sample_rate: Optional[int] = 16000,
|
57 |
+
device: Optional[torch.device] = None,
|
58 |
+
noises: Optional[List[np.ndarray]] = None,
|
59 |
+
augment_prob: float = 0.5,
|
60 |
+
):
|
61 |
+
self.samples = samples
|
62 |
+
self.orig_sample_rate = orig_sample_rate
|
63 |
+
self.target_sample_rate = target_sample_rate
|
64 |
+
self.device = device
|
65 |
+
self.noises = noises
|
66 |
+
self.augment_prob = augment_prob
|
67 |
+
self.resampler = torchaudio.transforms.Resample(
|
68 |
+
orig_freq=orig_sample_rate, new_freq=target_sample_rate
|
69 |
+
)
|
70 |
+
|
71 |
+
def __len__(self) -> int:
|
72 |
+
return len(self.samples)
|
73 |
+
|
74 |
+
def _normalize_waveform(self, wav: torch.Tensor) -> torch.Tensor:
|
75 |
+
"""Normalize waveform to [-1, 1]."""
|
76 |
+
max_val = wav.abs().max()
|
77 |
+
if max_val > 0:
|
78 |
+
wav = wav / max_val
|
79 |
+
return wav
|
80 |
+
|
81 |
+
def _add_time_varying_noise(
|
82 |
+
self, speech: torch.Tensor, noise: torch.Tensor, snr_db: float
|
83 |
+
) -> torch.Tensor:
|
84 |
+
"""Add noise to a random segment of the speech with fade-in/fade-out."""
|
85 |
+
speech_len = speech.size(0)
|
86 |
+
noise_len = noise.size(0)
|
87 |
+
|
88 |
+
# Match noise length (loop or trim)
|
89 |
+
if noise_len < speech_len:
|
90 |
+
repeats = int(np.ceil(speech_len / noise_len))
|
91 |
+
noise = noise.repeat(repeats)[:speech_len]
|
92 |
+
else:
|
93 |
+
noise = noise[:speech_len]
|
94 |
+
|
95 |
+
# Random segment (50%-100% of speech length)
|
96 |
+
seg_len = int(speech_len * np.random.uniform(0.5, 1.0))
|
97 |
+
start = np.random.randint(0, speech_len - seg_len + 1)
|
98 |
+
end = start + seg_len
|
99 |
+
|
100 |
+
# Compute noise scaling based on SNR
|
101 |
+
speech_energy = torch.mean(speech[start:end] ** 2)
|
102 |
+
noise_energy = torch.mean(noise[start:end] ** 2)
|
103 |
+
snr_linear = 10 ** (snr_db / 10.0)
|
104 |
+
noise_scale = torch.sqrt(speech_energy / (noise_energy * snr_linear + 1e-10))
|
105 |
+
|
106 |
+
# Apply noise to segment with fade-in/fade-out
|
107 |
+
fade_len = min(1000, seg_len // 4) # Fade over 1000 samples or 1/4 segment
|
108 |
+
fade_in = torch.linspace(0, 1, fade_len)
|
109 |
+
fade_out = torch.linspace(1, 0, fade_len)
|
110 |
+
mask = torch.ones(seg_len)
|
111 |
+
if fade_len > 0:
|
112 |
+
mask[:fade_len] = fade_in
|
113 |
+
mask[-fade_len:] = fade_out
|
114 |
+
|
115 |
+
noisy_segment = speech[start:end] + (noise_scale * noise[start:end] * mask)
|
116 |
+
noisy_speech = speech.clone()
|
117 |
+
noisy_speech[start:end] = noisy_segment
|
118 |
+
|
119 |
+
return torch.clamp(noisy_speech, -1, 1)
|
120 |
+
|
121 |
+
def _augment_with_noise(self, wav: torch.Tensor) -> torch.Tensor:
|
122 |
+
"""Augment waveform with random noise mixture."""
|
123 |
+
if not self.noises or len(self.noises) == 0:
|
124 |
+
return wav
|
125 |
+
|
126 |
+
# Decide how many noises to mix (1 or 2)
|
127 |
+
num_noises = np.random.randint(1, 3) # 1 or 2 noises
|
128 |
+
random_indices = np.random.randint(0, len(self.noises), size=num_noises)
|
129 |
+
selected_noises = [self.noises[i] for i in random_indices]
|
130 |
+
noisy_wav = wav.clone()
|
131 |
+
for noise_np in selected_noises:
|
132 |
+
noise = torch.from_numpy(noise_np).float()
|
133 |
+
noise = self._normalize_waveform(noise) # Normalize noise
|
134 |
+
snr_db = np.random.uniform(0, 20) # Random SNR between 0-20 dB
|
135 |
+
noisy_wav = self._add_time_varying_noise(noisy_wav, noise, snr_db)
|
136 |
+
|
137 |
+
# Volume normalization: re-normalize after mixing
|
138 |
+
noisy_wav = self._normalize_waveform(noisy_wav)
|
139 |
+
return noisy_wav
|
140 |
+
|
141 |
+
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
142 |
+
sample = self.samples[idx]
|
143 |
+
|
144 |
+
# Convert NumPy wav to torch tensor and resample
|
145 |
+
wav_tensor = torch.from_numpy(sample.wav).float()
|
146 |
+
if self.orig_sample_rate != self.target_sample_rate:
|
147 |
+
wav_tensor = self.resampler(wav_tensor)
|
148 |
+
|
149 |
+
# Normalize to [-1, 1]
|
150 |
+
wav_tensor = self._normalize_waveform(wav_tensor)
|
151 |
+
|
152 |
+
# Apply noise augmentation with probability
|
153 |
+
if self.noises and np.random.rand() < self.augment_prob:
|
154 |
+
wav_tensor = self._augment_with_noise(wav_tensor)
|
155 |
+
|
156 |
+
# Convert semantic to torch tensor (assuming integer tokens for CTC)
|
157 |
+
semantic_tensor = torch.from_numpy(sample.semantic).long()
|
158 |
+
|
159 |
+
# Move to device if specified
|
160 |
+
if self.device is not None:
|
161 |
+
wav_tensor = wav_tensor.to(self.device)
|
162 |
+
semantic_tensor = semantic_tensor.to(self.device)
|
163 |
+
|
164 |
+
return wav_tensor, semantic_tensor
|
165 |
+
|
166 |
+
|
167 |
+
def wav_semantic_collate_fn(
|
168 |
+
batch: List[Tuple[torch.Tensor, torch.Tensor]],
|
169 |
+
sos_token: int = SOS_TOKEN, # Adjust based on your vocab
|
170 |
+
eos_token: int = EOS_TOKEN, # Adjust based on your vocab
|
171 |
+
padding_token: int = SEMANTIC_PADDING_TOKEN, # Adjust based on your vocab
|
172 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
173 |
+
"""
|
174 |
+
Collate function for wav and semantic token pairs, adding <SOS> and <EOS> to targets.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
batch: List of (wav_tensor, semantic_tensor) tuples.
|
178 |
+
sos_token: Index of the <SOS> token.
|
179 |
+
eos_token: Index of the <EOS> token.
|
180 |
+
padding_token: Index of the padding token.
|
181 |
+
|
182 |
+
Returns:
|
183 |
+
Tuple of (padded_wavs, padded_targets, wav_lengths, target_lengths).
|
184 |
+
- padded_wavs: [B, max_wav_len]
|
185 |
+
- padded_targets: [B, max_target_len] with <SOS> and <EOS>
|
186 |
+
- wav_lengths: [B] (original wav lengths)
|
187 |
+
- target_lengths: [B] (original semantic lengths + 2 for <SOS> and <EOS>)
|
188 |
+
"""
|
189 |
+
waves, semantics = zip(*batch)
|
190 |
+
# Add <SOS> and <EOS> to each semantic sequence
|
191 |
+
semantics_with_tokens = [
|
192 |
+
torch.cat(
|
193 |
+
[
|
194 |
+
torch.tensor([sos_token], dtype=torch.long, device=semantic.device),
|
195 |
+
semantic,
|
196 |
+
torch.tensor([eos_token], dtype=torch.long, device=semantic.device),
|
197 |
+
]
|
198 |
+
)
|
199 |
+
for semantic in semantics
|
200 |
+
]
|
201 |
+
|
202 |
+
# Compute lengths *after* adding <SOS> and <EOS>
|
203 |
+
wav_lengths = torch.tensor([wav.size(0) for wav in waves], dtype=torch.long)
|
204 |
+
target_lengths = torch.tensor(
|
205 |
+
[semantic.size(0) for semantic in semantics_with_tokens], dtype=torch.long
|
206 |
+
)
|
207 |
+
|
208 |
+
# Pad waves and targets to max length in batch
|
209 |
+
max_wav_len = max(wav_lengths).item()
|
210 |
+
max_target_len = max(target_lengths).item()
|
211 |
+
|
212 |
+
padded_wavs = torch.zeros(size=(len(waves), max_wav_len), device=waves[0].device)
|
213 |
+
padded_targets = torch.full(
|
214 |
+
size=(len(semantics), max_target_len),
|
215 |
+
fill_value=padding_token,
|
216 |
+
dtype=torch.long,
|
217 |
+
device=semantics[0].device,
|
218 |
+
)
|
219 |
+
|
220 |
+
for i, (wav, semantic) in enumerate(zip(waves, semantics_with_tokens)):
|
221 |
+
padded_wavs[i, : wav.size(0)] = wav
|
222 |
+
padded_targets[i, : semantic.size(0)] = semantic
|
223 |
+
|
224 |
+
return padded_wavs, padded_targets, wav_lengths, target_lengths
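For reference, a quick sanity check of the collate function on a hypothetical two-sample batch (only the shapes matter here; the SOS/EOS/padding indices come from the constants above):

import torch

fake_batch = [
    (torch.randn(16000), torch.randint(0, 100, (10,))),
    (torch.randn(24000), torch.randint(0, 100, (7,))),
]
wavs, targets, wav_lens, target_lens = wav_semantic_collate_fn(fake_batch)
print(wavs.shape)            # torch.Size([2, 24000])  -> padded to the longest wav
print(targets.shape)         # torch.Size([2, 12])     -> 10 tokens + <SOS> + <EOS>
print(wav_lens.tolist())     # [16000, 24000]
print(target_lens.tolist())  # [12, 9]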
|
225 |
+
|
226 |
+
|
227 |
+
def load_train_val_dataloaders(
|
228 |
+
dataset: WavSemanticDataset,
|
229 |
+
train_ratio: float,
|
230 |
+
batch_size: int,
|
231 |
+
target_sample_rate: int = 16000,
|
232 |
+
noises: List[np.ndarray] = None,
|
233 |
+
augment_prob: float = 0.5,
|
234 |
+
device: Optional[torch.device] = None,
|
235 |
+
) -> Tuple[DataLoader, DataLoader]:
|
236 |
+
"""
|
237 |
+
Load train and validation DataLoaders from a WavSemanticDataset with dynamic batch padding.
|
238 |
+
|
239 |
+
Args:
|
240 |
+
dataset: The WavSemanticDataset instance to split and load.
|
241 |
+
train_ratio: Fraction of data to use for training (0 to 1).
|
242 |
+
batch_size: Number of samples per batch.
|
243 |
+
target_sample_rate: Target sample rate for resampling (default: 16000 Hz).
|
244 |
+
device: Optional device to move tensors to (default: None, stays on CPU).
|
245 |
+
|
246 |
+
Returns:
|
247 |
+
Tuple of (train_dataloader, val_dataloader).
|
248 |
+
"""
|
249 |
+
# Split dataset into train and val
|
250 |
+
total_samples = len(dataset.data)
|
251 |
+
train_size = int(train_ratio * total_samples)
|
252 |
+
val_size = total_samples - train_size
|
253 |
+
train_data, val_data = random_split(dataset.data, [train_size, val_size])
|
254 |
+
|
255 |
+
# Create datasets without fixed max_sequence_length
|
256 |
+
train_dataset = WavSemanticTorchDataset(
|
257 |
+
samples=train_data,
|
258 |
+
orig_sample_rate=dataset.sample_rate,
|
259 |
+
target_sample_rate=target_sample_rate,
|
260 |
+
device=device,
|
261 |
+
noises=noises,
|
262 |
+
augment_prob=augment_prob,
|
263 |
+
)
|
264 |
+
val_dataset = WavSemanticTorchDataset(
|
265 |
+
samples=val_data,
|
266 |
+
orig_sample_rate=dataset.sample_rate,
|
267 |
+
target_sample_rate=target_sample_rate,
|
268 |
+
device=device,
|
269 |
+
noises=noises,
|
270 |
+
augment_prob=augment_prob,
|
271 |
+
)
|
272 |
+
|
273 |
+
# Create dataloaders with custom collate function
|
274 |
+
train_dataloader = DataLoader(
|
275 |
+
train_dataset,
|
276 |
+
batch_size=batch_size,
|
277 |
+
shuffle=True,
|
278 |
+
num_workers=0, # Increase if you have multiple cores
|
279 |
+
collate_fn=wav_semantic_collate_fn,
|
280 |
+
)
|
281 |
+
val_dataloader = DataLoader(
|
282 |
+
val_dataset,
|
283 |
+
batch_size=batch_size,
|
284 |
+
shuffle=False,
|
285 |
+
num_workers=0,
|
286 |
+
collate_fn=wav_semantic_collate_fn,
|
287 |
+
)
|
288 |
+
|
289 |
+
return train_dataloader, val_dataloader
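A hedged sketch of wiring the loaders up, assuming a dataset previously saved under ./wav_semantic_dataset (the path, sample count, and batch size are illustrative):

import torch

dataset = WavSemanticDataset.load("./wav_semantic_dataset", num_samples=1000)
train_loader, val_loader = load_train_val_dataloaders(
    dataset,
    train_ratio=0.8,
    batch_size=16,
    target_sample_rate=16000,
    noises=None,          # pass a list of np.ndarray noise clips to enable augmentation
    augment_prob=0.5,
    device=torch.device("cpu"),
)
for wavs, targets, wav_lens, target_lens in train_loader:
    break  # each batch is already padded by wav_semantic_collate_fn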
|
290 |
+
|
291 |
+
|
292 |
+
def train_hubert_one_epoch(
|
293 |
+
model: nn.Module,
|
294 |
+
optimizer: torch.optim.Optimizer,
|
295 |
+
criterion: nn.CrossEntropyLoss,
|
296 |
+
train_dataloader: DataLoader,
|
297 |
+
grad_scaler: torch.amp.GradScaler,
|
298 |
+
device: torch.device,
|
299 |
+
progress_bar: Optional[tqdm] = None,
|
300 |
+
enable_autocast: bool = False,
|
301 |
+
) -> Dict[str, float]:
|
302 |
+
"""
|
303 |
+
Train the HuBERT model for one epoch using mixed-precision training with CrossEntropyLoss.
|
304 |
+
|
305 |
+
Args:
|
306 |
+
model: The HuBERT model with Transformer decoder.
|
307 |
+
optimizer: Optimizer for updating model parameters.
|
308 |
+
criterion: CrossEntropyLoss function.
|
309 |
+
train_dataloader: DataLoader for training data.
|
310 |
+
grad_scaler: Gradient scaler for mixed-precision training.
|
311 |
+
device: Device to train on (e.g., 'cuda', 'mps', 'cpu').
|
312 |
+
progress_bar: Optional tqdm progress bar.
|
313 |
+
|
314 |
+
Returns:
|
315 |
+
Dict with 'loss' metric.
|
316 |
+
"""
|
317 |
+
model.train()
|
318 |
+
total_loss = 0.0
|
319 |
+
for batch in train_dataloader:
|
320 |
+
# tensors were already moved to the target device in the Dataset's __getitem__
|
321 |
+
waves, targets = batch[0], batch[1]
|
322 |
+
optimizer.zero_grad()
|
323 |
+
with torch.autocast(
|
324 |
+
device_type=device.type, dtype=torch.bfloat16, enabled=enable_autocast
|
325 |
+
):
|
326 |
+
|
327 |
+
logits: torch.Tensor = model(waves, targets)
|
328 |
+
|
329 |
+
loss = criterion(logits[:, :-1, :].transpose(1, 2), targets[:, 1:])
|
330 |
+
|
331 |
+
total_loss += loss.detach().item()
|
332 |
+
|
333 |
+
# Mixed precision with scaler (remove scaler if autocast is disabled)
|
334 |
+
grad_scaler.scale(loss).backward()
|
335 |
+
grad_scaler.step(optimizer)
|
336 |
+
grad_scaler.update()
|
337 |
+
|
338 |
+
if progress_bar is not None:
|
339 |
+
progress_bar.update(1)
|
340 |
+
|
341 |
+
avg_loss = total_loss / len(train_dataloader)
|
342 |
+
return {"loss": avg_loss}
|
343 |
+
|
344 |
+
|
345 |
+
def eval_hubert(
|
346 |
+
model: nn.Module,
|
347 |
+
criterion: nn.CrossEntropyLoss,
|
348 |
+
val_dataloader: DataLoader,
|
349 |
+
device: torch.device,
|
350 |
+
sos_token: int = SOS_TOKEN,
|
351 |
+
eos_token: int = EOS_TOKEN,
|
352 |
+
padding_token: int = SEMANTIC_PADDING_TOKEN,
|
353 |
+
) -> Dict[str, float]:
|
354 |
+
"""
|
355 |
+
Evaluate the updated HuBERT model with Transformer decoder on the validation set.
|
356 |
+
|
357 |
+
Args:
|
358 |
+
model: The HuBERT model with Transformer decoder.
|
359 |
+
criterion: CrossEntropyLoss function.
|
360 |
+
val_dataloader: DataLoader for validation data (waves, targets).
|
361 |
+
device: Device to evaluate on.
|
362 |
+
sos_token: Index of the <SOS> token.
|
363 |
+
eos_token: Index of the <EOS> token.
|
364 |
+
padding_token: Index of the padding token.
|
365 |
+
|
366 |
+
Returns:
|
367 |
+
Dict with 'loss', 'accuracy', and 'num_tokens' metrics.
|
368 |
+
"""
|
369 |
+
model.eval()
|
370 |
+
total_loss = 0.0
|
371 |
+
total_correct = 0
|
372 |
+
total_tokens = 0
|
373 |
+
num_batches = 0
|
374 |
+
|
375 |
+
for batch in val_dataloader:
|
376 |
+
# targets: [B, T'] with <SOS> and <EOS>
|
377 |
+
waves, targets = batch[0].to(device), batch[1].to(device)
|
378 |
+
|
379 |
+
with torch.no_grad(), torch.autocast(
|
380 |
+
device_type=device.type, dtype=torch.bfloat16
|
381 |
+
):
|
382 |
+
# [B, T', semantic_vocab_size]
|
383 |
+
# transformers use batch_first=True
|
384 |
+
# targets is a tensor of [B, T'], all including [SOS] and [EOS] tokens
|
385 |
+
logits: torch.Tensor = model(waves, targets)
|
386 |
+
|
387 |
+
# remove the last token predictions from the logits
|
388 |
+
# remove the first token, which is SOS token from the targets
|
389 |
+
# transpose the logits tensor from (B, T, C) to (B, C, T)
|
390 |
+
loss = criterion(logits[:, :-1, :].transpose(1, 2), targets[:, 1:])
|
391 |
+
|
392 |
+
# Calculate accuracy (ignoring padding tokens)
|
393 |
+
preds = logits.argmax(dim=-1)[:, :-1]
|
394 |
+
target_shifted = targets[:, 1:]
|
395 |
+
mask = target_shifted != padding_token
|
396 |
+
total_correct += (preds[mask] == target_shifted[mask]).sum().item()
|
397 |
+
total_tokens += mask.sum().item()
|
398 |
+
|
399 |
+
total_loss += loss.item()
|
400 |
+
num_batches += 1
|
401 |
+
|
402 |
+
avg_loss = total_loss / num_batches
|
403 |
+
accuracy = total_correct / total_tokens if total_tokens > 0 else 0.0
|
404 |
+
|
405 |
+
return {"loss": avg_loss, "accuracy": accuracy, "num_tokens": total_tokens}
|
406 |
+
|
407 |
+
|
408 |
+
def _load_noise_dataset(data_path: str, target_sample_rate: int) -> List[np.ndarray]:
|
409 |
+
data = []
|
410 |
+
# Add more extensions as needed ".flac", ".ogg", ".aiff"
|
411 |
+
audio_extensions = (".wav", ".mp3")
|
412 |
+
|
413 |
+
# Walk through all directories and subdirectories
|
414 |
+
for root, dirs, files in os.walk(data_path):
|
415 |
+
for filename in files:
|
416 |
+
# Check if the file has an audio extension
|
417 |
+
if filename.lower().endswith(audio_extensions):
|
418 |
+
filepath = os.path.join(root, filename)
|
419 |
+
try:
|
420 |
+
audio = read_audio_file(
|
421 |
+
filepath,
|
422 |
+
target_sample_rate=target_sample_rate,
|
423 |
+
channels=1,
|
424 |
+
normalize=False,
|
425 |
+
)
|
426 |
+
data.append(audio)
|
427 |
+
except Exception as e:
|
428 |
+
print(f"Warning: Could not load {filepath}: {str(e)}")
|
429 |
+
continue
|
430 |
+
|
431 |
+
if len(data) == 0:
|
432 |
+
raise RuntimeError(f"No audio files found in {data_path} or its subdirectories")
|
433 |
+
|
434 |
+
return data
|
435 |
+
|
436 |
+
|
437 |
+
def train_hubert_quantizer(
|
438 |
+
model: nn.Module,
|
439 |
+
model_config: Dict[str, Any],
|
440 |
+
lr: float,
|
441 |
+
num_epoch: int,
|
442 |
+
train_ratio: float = 0.8,
|
443 |
+
batch_size: int = 64,
|
444 |
+
data_path: str = "./wav_semantic_dataset",
|
445 |
+
checkpoint_path: str = "./checkpoints",
|
446 |
+
save_checkpoint_every: int = 2,
|
447 |
+
enable_grad_scaler: bool = False,
|
448 |
+
augment_data_with_noise: bool = False,
|
449 |
+
augment_prob: float = 0.5,
|
450 |
+
noise_data_path: str = "./noise_dataset",
|
451 |
+
publish_hf: bool = False,
|
452 |
+
publish_to_repo: str = "",
|
453 |
+
num_samples: int = 5000,
|
454 |
+
device: torch.device = torch.device("cuda"),
|
455 |
+
) -> nn.Module:
|
456 |
+
"""
|
457 |
+
Train a HuBERT model with mixed-precision training and save checkpoints.
|
458 |
+
|
459 |
+
Args:
|
460 |
+
model: The HuBERT model to train.
|
461 |
+
lr: Learning rate for the optimizer.
|
462 |
+
num_epoch: Number of epochs to train.
|
463 |
+
train_ratio: Fraction of data for training.
|
464 |
+
batch_size: Batch size for DataLoaders.
|
465 |
+
data_path: Path to the saved dataset.
|
466 |
+
checkpoint_path: Directory to save checkpoints.
|
467 |
+
save_checkpoint_every: Save checkpoint every N epochs.
|
468 |
+
augment_data_with_noise: whether to add random noise to training audio
|
469 |
+
augment_prob: probability that a sample will be augmented with noise
|
470 |
+
num_samples: maximum number of samples to load from the dataset
|
471 |
+
Returns:
|
472 |
+
The trained model.
|
473 |
+
"""
|
474 |
+
|
475 |
+
# else "mps" if torch.backends.mps.is_available()
|
476 |
+
# mixed-precision training doesn't work on the mps device at the grad_scaler.step(optimizer) step
|
477 |
+
# for testing just run on cpu
|
478 |
+
model.to(device)
|
479 |
+
|
480 |
+
# Load dataset and create dataloaders
|
481 |
+
dataset = WavSemanticDataset.load(data_path, num_samples=num_samples)
|
482 |
+
noises = None
|
483 |
+
if augment_data_with_noise:
|
484 |
+
logger.info(f"reading noise data from {noise_data_path}")
|
485 |
+
noises = _load_noise_dataset(noise_data_path, target_sample_rate=16000)
|
486 |
+
|
487 |
+
train_dataloader, val_dataloader = load_train_val_dataloaders(
|
488 |
+
dataset,
|
489 |
+
train_ratio=train_ratio,
|
490 |
+
batch_size=batch_size,
|
491 |
+
target_sample_rate=HUBERT_SAMPLE_RATE,
|
492 |
+
noises=noises,
|
493 |
+
augment_prob=augment_prob,
|
494 |
+
device=device,
|
495 |
+
)
|
496 |
+
|
497 |
+
optimizer = Adam(model.parameters(), lr=lr)
|
498 |
+
criterion = nn.CrossEntropyLoss(ignore_index=SEMANTIC_PADDING_TOKEN)
|
499 |
+
grad_scaler = torch.amp.GradScaler(device.type, enabled=enable_grad_scaler)
|
500 |
+
progress_bar = tqdm(total=num_epoch * len(train_dataloader), desc="Training HuBERT")
|
501 |
+
# scheduler = LinearLR(
|
502 |
+
# optimizer, start_factor=1, end_factor=0.5, total_iters=(num_epoch / 2)
|
503 |
+
# )
|
504 |
+
scheduler = None
|
505 |
+
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
|
506 |
+
|
507 |
+
for epoch in range(num_epoch):
|
508 |
+
train_result = train_hubert_one_epoch(
|
509 |
+
model=model,
|
510 |
+
optimizer=optimizer,
|
511 |
+
criterion=criterion,
|
512 |
+
train_dataloader=train_dataloader,
|
513 |
+
grad_scaler=grad_scaler,
|
514 |
+
device=device,
|
515 |
+
progress_bar=progress_bar,
|
516 |
+
enable_autocast=enable_grad_scaler,
|
517 |
+
)
|
518 |
+
with torch.no_grad():
|
519 |
+
eval_result = eval_hubert(
|
520 |
+
model=model,
|
521 |
+
criterion=criterion,
|
522 |
+
val_dataloader=val_dataloader,
|
523 |
+
device=device,
|
524 |
+
)
|
525 |
+
|
526 |
+
if scheduler is not None:
|
527 |
+
scheduler.step()
|
528 |
+
|
529 |
+
logger.info(
|
530 |
+
f"Epoch {epoch + 1}/{num_epoch}, Train: {train_result}, Eval: {eval_result}"
|
531 |
+
)
|
532 |
+
|
533 |
+
if (epoch + 1) % save_checkpoint_every == 0:
|
534 |
+
checkpoint_file = os.path.join(
|
535 |
+
checkpoint_path,
|
536 |
+
f"hubert_epoch_{epoch + 1}_{datetime.now().strftime('%Y_%m_%d_%H_%M')}_eval_loss_{eval_result.get('loss', 0)}_acc_{eval_result.get('accuracy', 0)}.pt",
|
537 |
+
)
|
538 |
+
torch.save(
|
539 |
+
{ # save the model configuration alongside the weights for later loading
|
540 |
+
"epoch": epoch + 1,
|
541 |
+
"model_state_dict": model.state_dict(),
|
542 |
+
# "optimizer_state_dict": optimizer.state_dict(),
|
543 |
+
"train_result": train_result,
|
544 |
+
"eval_result": eval_result,
|
545 |
+
"config": model_config,
|
546 |
+
},
|
547 |
+
checkpoint_file,
|
548 |
+
)
|
549 |
+
logger.info(f"Saved checkpoint to {checkpoint_file}")
|
550 |
+
|
551 |
+
if publish_hf:
|
552 |
+
upload_file_to_hf(checkpoint_file, publish_to_repo, "model")
|
553 |
+
|
554 |
+
progress_bar.close()
|
555 |
+
return model
|
core/utils/__init__.py
ADDED
@@ -0,0 +1,7 @@
1 |
+
from core.utils.audio import *
|
2 |
+
|
3 |
+
from core.utils.text import *
|
4 |
+
|
5 |
+
from core.utils.read_write_files import *
|
6 |
+
|
7 |
+
from core.utils.huggingface import *
|
core/utils/audio.py
ADDED
@@ -0,0 +1,104 @@
1 |
+
"""
|
2 |
+
Helpful functions to process audio
|
3 |
+
"""
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import soundfile as sf
|
7 |
+
|
8 |
+
from typing_extensions import Annotated, Literal, Optional
|
9 |
+
import torchaudio
|
10 |
+
import torch
|
11 |
+
|
12 |
+
AudioChannel = Literal[1, 2]
|
13 |
+
|
14 |
+
|
15 |
+
def read_audio_file(
|
16 |
+
path: str,
|
17 |
+
target_sample_rate: int = 16000,
|
18 |
+
channels: int = 1,
|
19 |
+
normalize: bool = True,
|
20 |
+
max_duration: Optional[float] = None,
|
21 |
+
) -> np.ndarray:
|
22 |
+
"""Read and resample audio file
|
23 |
+
If target_sample_rate differs from the audio's native sample rate, the audio is resampled.
|
24 |
+
If GPU is available, the resampling will be on GPU.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
path: Path to the audio file (supports WAV, FLAC, OGG)
|
28 |
+
target_sample_rate: Target sample rate in Hz (default: 16000)
|
29 |
+
channels: Number of output channels (1 for mono, 2 for stereo)
|
30 |
+
normalize: Whether to normalize audio to [-1, 1]
|
31 |
+
max_duration: Maximum duration in seconds (truncates longer files)
|
32 |
+
Note: resampling runs on CUDA automatically when available; there is no device argument.
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
np.ndarray: Processed audio samples as a numpy array
|
36 |
+
|
37 |
+
Raises:
|
38 |
+
RuntimeError: If the file cannot be read or processing fails
|
39 |
+
"""
|
40 |
+
try:
|
41 |
+
# Load audio file with torchaudio
|
42 |
+
waveform, original_sample_rate = torchaudio.load(path) # [channels, samples]
|
43 |
+
|
44 |
+
# Truncate to max_duration before resampling
|
45 |
+
if max_duration is not None:
|
46 |
+
max_samples = int(max_duration * original_sample_rate)
|
47 |
+
if waveform.size(1) > max_samples:
|
48 |
+
waveform = waveform[:, :max_samples]
|
49 |
+
|
50 |
+
# Downmix to desired channels
|
51 |
+
if waveform.size(0) > channels:
|
52 |
+
if channels == 1:
|
53 |
+
waveform = waveform.mean(dim=0, keepdim=True) # Mono: average channels
|
54 |
+
elif channels == 2:
|
55 |
+
waveform = waveform[:2, :] # Stereo: take first 2 channels
|
56 |
+
|
57 |
+
# Resample if needed
|
58 |
+
if original_sample_rate != target_sample_rate:
|
59 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
60 |
+
waveform = waveform.to(device)
|
61 |
+
resampler = torchaudio.transforms.Resample(
|
62 |
+
orig_freq=original_sample_rate,
|
63 |
+
new_freq=target_sample_rate,
|
64 |
+
resampling_method="sinc_interp_kaiser", # Fast and high-quality
|
65 |
+
).to(device)
|
66 |
+
waveform = resampler(waveform)
|
67 |
+
|
68 |
+
# Normalize to [-1, 1] if requested
|
69 |
+
if normalize:
|
70 |
+
max_val = waveform.abs().max()
|
71 |
+
if max_val > 0:
|
72 |
+
waveform = waveform / max_val
|
73 |
+
|
74 |
+
# Move back to CPU and convert to numpy
|
75 |
+
data = waveform.cpu().numpy()
|
76 |
+
|
77 |
+
# Ensure correct shape (remove extra dim if mono)
|
78 |
+
if channels == 1 and data.shape[0] == 1:
|
79 |
+
data = data[0, :]
|
80 |
+
|
81 |
+
return data
|
82 |
+
|
83 |
+
except Exception as e:
|
84 |
+
raise RuntimeError(f"Failed to read audio file {path}: {str(e)}")
|
85 |
+
|
86 |
+
|
87 |
+
def save_audio_file(
|
88 |
+
audio_array: np.ndarray, sample_rate: int, file_path: str, format="WAV"
|
89 |
+
):
|
90 |
+
"""
|
91 |
+
Save an audio array to a file.
|
92 |
+
|
93 |
+
Parameters:
|
94 |
+
- audio_array: numpy array or list containing the audio samples
|
95 |
+
- sample_rate: int, the sample rate of the audio (e.g., 44100 Hz)
|
96 |
+
- file_path: str, path where the file will be saved (e.g., 'output.wav')
|
97 |
+
- format: str, audio file format (e.g., 'WAV', 'FLAC', 'OGG'), default is 'WAV'
|
98 |
+
"""
|
99 |
+
try:
|
100 |
+
if not file_path.endswith(".wav"):
|
101 |
+
file_path += ".wav"
|
102 |
+
sf.write(file_path, audio_array, sample_rate, format=format)
|
103 |
+
except Exception as e:
|
104 |
+
print(f"Error saving audio file at {file_path}: {e}")
|
core/utils/huggingface.py
ADDED
@@ -0,0 +1,169 @@
1 |
+
import logging
|
2 |
+
import sys
|
3 |
+
from typing import Optional, Literal
|
4 |
+
import os
|
5 |
+
import shutil
|
6 |
+
from zipfile import ZipFile
|
7 |
+
from pathlib import Path
|
8 |
+
from huggingface_hub import hf_hub_download, upload_file
|
9 |
+
|
10 |
+
# Set up logging
|
11 |
+
logging.basicConfig(
|
12 |
+
level=logging.INFO,
|
13 |
+
format="%(asctime)s - %(levelname)s - %(message)s",
|
14 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
15 |
+
)
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
__all__ = ["download_dataset_from_hf", "upload_file_to_hf", "download_file_from_hf"]
|
19 |
+
|
20 |
+
|
21 |
+
def download_dataset_from_hf(
|
22 |
+
repo_id: str,
|
23 |
+
filename: str,
|
24 |
+
dest_path: str,
|
25 |
+
token: str = None,
|
26 |
+
local_dir: str = "./downloads",
|
27 |
+
remove_downloaded_file: bool = True,
|
28 |
+
) -> None:
|
29 |
+
"""
|
30 |
+
Download a file from Hugging Face repository and unzip it to destination path
|
31 |
+
|
32 |
+
Args:
|
33 |
+
repo_id (str): Hugging Face repository ID (username/repo_name)
|
34 |
+
filename (str): Name of the file to download from the repository
|
35 |
+
dest_path (str): Destination path where contents will be unzipped
|
36 |
+
token (str, optional): Hugging Face token; if None, read from the HF_TOKEN environment variable
|
37 |
+
"""
|
38 |
+
# Ensure destination directory exists
|
39 |
+
os.makedirs(dest_path, exist_ok=True)
|
40 |
+
if token is None:
|
41 |
+
logger.info("reading HF_TOKEN variable from environment")
|
42 |
+
token = os.getenv("HF_TOKEN")
|
43 |
+
|
44 |
+
# Download the file
|
45 |
+
downloaded_file = hf_hub_download(
|
46 |
+
repo_id=repo_id,
|
47 |
+
filename=filename,
|
48 |
+
repo_type="dataset", # Specify dataset repository
|
49 |
+
local_dir=local_dir, # Temporary download location
|
50 |
+
token=token,
|
51 |
+
)
|
52 |
+
logger.info(f"Downloaded {filename} to {downloaded_file}")
|
53 |
+
|
54 |
+
# Check if it's a zip file
|
55 |
+
if filename.endswith(".zip"):
|
56 |
+
# Extract the zip file
|
57 |
+
with ZipFile(downloaded_file, "r") as zip_ref:
|
58 |
+
zip_ref.extractall(dest_path)
|
59 |
+
logger.info(f"Unzipped contents to {dest_path}")
|
60 |
+
|
61 |
+
# Clean up the downloaded zip file
|
62 |
+
if remove_downloaded_file:
|
63 |
+
os.remove(downloaded_file)
|
64 |
+
logger.info(f"Cleaned up temporary file: {downloaded_file}")
|
65 |
+
else:
|
66 |
+
# If not a zip, just move the file
|
67 |
+
final_path = os.path.join(dest_path, filename)
|
68 |
+
shutil.move(downloaded_file, final_path)
|
69 |
+
logger.info(f"Moved {filename} to {final_path}")
|
70 |
+
|
71 |
+
|
72 |
+
def download_file_from_hf(
|
73 |
+
repo_id: str,
|
74 |
+
repo_type: Literal["model", "dataset"],
|
75 |
+
filename: str,
|
76 |
+
dest_path: str,
|
77 |
+
token: str = None,
|
78 |
+
) -> None:
|
79 |
+
"""
|
80 |
+
Download a file from Hugging Face repository and unzip it to destination path
|
81 |
+
|
82 |
+
Args:
|
83 |
+
repo_id (str): Hugging Face repository ID (username/repo_name)
|
84 |
+
repo_type: model for model repo, dataset for dataset repo
|
85 |
+
filename (str): Name of the file to download from the repository
|
86 |
+
dest_path (str): Destination path where contents will be unzipped
|
87 |
+
token (str, optional): Hugging Face token; if None, read from the HF_TOKEN environment variable
|
88 |
+
|
89 |
+
"""
|
90 |
+
# Ensure destination directory exists
|
91 |
+
os.makedirs(dest_path, exist_ok=True)
|
92 |
+
if token is None:
|
93 |
+
logger.info("reading HF_TOKEN variable from environment")
|
94 |
+
token = os.getenv("HF_TOKEN")
|
95 |
+
|
96 |
+
# Download the file
|
97 |
+
downloaded_file = hf_hub_download(
|
98 |
+
repo_id=repo_id,
|
99 |
+
filename=filename,
|
100 |
+
repo_type=repo_type,
|
101 |
+
local_dir="./downloads", # Temporary download location
|
102 |
+
token=token,
|
103 |
+
)
|
104 |
+
logger.info(f"Downloaded {filename} to {downloaded_file}")
|
105 |
+
|
106 |
+
# Check if it's a zip file
|
107 |
+
if filename.endswith(".zip"):
|
108 |
+
# Extract the zip file
|
109 |
+
with ZipFile(downloaded_file, "r") as zip_ref:
|
110 |
+
zip_ref.extractall(dest_path)
|
111 |
+
logger.info(f"Unzipped contents to {dest_path}")
|
112 |
+
|
113 |
+
# Clean up the downloaded zip file
|
114 |
+
os.remove(downloaded_file)
|
115 |
+
logger.info(f"Cleaned up temporary file: {downloaded_file}")
|
116 |
+
else:
|
117 |
+
# If not a zip, just move the file
|
118 |
+
final_path = os.path.join(dest_path, filename)
|
119 |
+
shutil.move(downloaded_file, final_path)
|
120 |
+
logger.info(f"Moved {filename} to {final_path}")
|
121 |
+
|
122 |
+
|
123 |
+
def upload_file_to_hf(
|
124 |
+
local_file_path: str,
|
125 |
+
repo_id: str,
|
126 |
+
repo_type: Literal["model", "dataset"],
|
127 |
+
token: Optional[str] = None,
|
128 |
+
path_in_repo: Optional[str] = None,
|
129 |
+
commit_message: str = "Upload file",
|
130 |
+
) -> None:
|
131 |
+
"""
|
132 |
+
Upload a file to Hugging Face hub.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
local_file_path (str): Path to the local .pt checkpoint file
|
136 |
+
repo_id (str): Repository ID in format "username/repo_name"
|
137 |
+
repo_type (str, optional): Type of repository, either "model" or "dataset"
|
138 |
+
token (str): Hugging Face authentication token. Read from the HF_TOKEN environment variable if not provided
|
139 |
+
path_in_repo (str, optional): Destination path in the repository.
|
140 |
+
Defaults to the filename from local_checkpoint_path
|
141 |
+
commit_message (str, optional): Commit message for the upload
|
142 |
+
|
143 |
+
Raises:
|
144 |
+
FileNotFoundError: If the checkpoint file doesn't exist
|
145 |
+
ValueError: If the repository ID is invalid
|
146 |
+
"""
|
147 |
+
# Validate file exists
|
148 |
+
if not os.path.isfile(local_file_path):
|
149 |
+
raise FileNotFoundError(f"File not found: {local_file_path}")
|
150 |
+
|
151 |
+
# Use filename as default path_in_repo if not specified
|
152 |
+
if path_in_repo is None:
|
153 |
+
path_in_repo = Path(local_file_path).name
|
154 |
+
|
155 |
+
if token is None:
|
156 |
+
logger.info("reading HF_TOKEN variable from environment")
|
157 |
+
token = os.getenv("HF_TOKEN")
|
158 |
+
if token is None:
|
159 |
+
raise RuntimeError("not found HF_TOKEN variable from environment")
|
160 |
+
|
161 |
+
upload_file(
|
162 |
+
path_or_fileobj=local_file_path,
|
163 |
+
path_in_repo=path_in_repo,
|
164 |
+
repo_id=repo_id,
|
165 |
+
repo_type=repo_type,
|
166 |
+
token=token,
|
167 |
+
commit_message=commit_message,
|
168 |
+
)
|
169 |
+
logger.info(f"Successfully uploaded {local_file_path} to {repo_id}/{path_in_repo}")
|
core/utils/read_write_files.py
ADDED
@@ -0,0 +1,46 @@
1 |
+
import os
|
2 |
+
import zipfile
|
3 |
+
|
4 |
+
|
5 |
+
def zip_folder(folder_path: str, output_path: str) -> bool:
|
6 |
+
"""
|
7 |
+
Zip a folder and its contents to a zip file.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
folder_path (str): Path to the folder to be zipped
|
11 |
+
output_path (str): Path where the zip file will be created
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
bool: True if successful, False otherwise
|
15 |
+
"""
|
16 |
+
try:
|
17 |
+
# Ensure the folder exists
|
18 |
+
if not os.path.isdir(folder_path):
|
19 |
+
print(f"Error: {folder_path} is not a valid directory")
|
20 |
+
return False
|
21 |
+
|
22 |
+
# Get the absolute path of the folder
|
23 |
+
abs_folder_path = os.path.abspath(folder_path)
|
24 |
+
|
25 |
+
# Create a ZipFile object in write mode
|
26 |
+
with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zipf:
|
27 |
+
# Walk through the folder
|
28 |
+
for root, dirs, files in os.walk(abs_folder_path):
|
29 |
+
for file in files:
|
30 |
+
# Get the absolute path of the file
|
31 |
+
abs_file_path = os.path.join(root, file)
|
32 |
+
|
33 |
+
# Calculate relative path for the file inside the zip
|
34 |
+
rel_path = os.path.relpath(
|
35 |
+
abs_file_path, os.path.dirname(abs_folder_path)
|
36 |
+
)
|
37 |
+
|
38 |
+
# Add file to zip
|
39 |
+
zipf.write(abs_file_path, rel_path)
|
40 |
+
|
41 |
+
print(f"Successfully created zip file at {output_path}")
|
42 |
+
return True
|
43 |
+
|
44 |
+
except Exception as e:
|
45 |
+
print(f"Error creating zip file: {e}")
|
46 |
+
return False
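Usage is a single call; the archive keeps the folder itself as the top-level entry (paths are illustrative):

if zip_folder("./dataset", "./dataset.zip"):
    print("archive ready for upload")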
|
core/utils/text.py
ADDED
@@ -0,0 +1,13 @@
1 |
+
def normalize_whitespace(text: str) -> str:
|
2 |
+
"""
|
3 |
+
Normalize whitespace in text by:
|
4 |
+
1. Removing leading and trailing whitespace
|
5 |
+
2. Replacing any sequence of whitespace characters with a single space
|
6 |
+
|
7 |
+
Args:
|
8 |
+
text: Input string to normalize
|
9 |
+
|
10 |
+
Returns:
|
11 |
+
String with normalized whitespace
|
12 |
+
"""
|
13 |
+
return ' '.join(text.split())
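For example:

assert normalize_whitespace("  hello \t world\n") == "hello world"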
|
event_handlers.py
ADDED
@@ -0,0 +1,436 @@
1 |
+
from typing import List, Tuple, Optional, Dict, Any
|
2 |
+
import traceback
|
3 |
+
import torch
|
4 |
+
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
import time
|
7 |
+
import os
|
8 |
+
import re
|
9 |
+
import wave
|
10 |
+
import contextlib
|
11 |
+
import logging
|
12 |
+
import pandas as pd
|
13 |
+
import gc
|
14 |
+
|
15 |
+
import nltk
|
16 |
+
|
17 |
+
nltk.download("punkt")
|
18 |
+
from nltk.tokenize import sent_tokenize
|
19 |
+
|
20 |
+
from core.data_model import AudioFile
|
21 |
+
from core.bark.voice_clone import create_bark_prompt
|
22 |
+
from core.bark.generate_audio import generate_audio
|
23 |
+
from core.data_model import BarkPrompt, BarkGenerationConfig
|
24 |
+
from core.utils.audio import save_audio_file
|
25 |
+
from config import *
|
26 |
+
|
27 |
+
# Set up logging
|
28 |
+
logging.basicConfig(
|
29 |
+
level=logging.INFO,
|
30 |
+
format="%(asctime)s - %(levelname)s - %(message)s",
|
31 |
+
)
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
|
35 |
+
# return list of available devices and the best device to be used as default for all inference
|
36 |
+
def get_available_torch_devices() -> Tuple[List[str], str]:
|
37 |
+
devices = ["cpu"]
|
38 |
+
best_device = "cpu"
|
39 |
+
if torch.backends.mps.is_available():
|
40 |
+
devices.append("mps")
|
41 |
+
best_device = "mps"
|
42 |
+
if torch.cuda.is_available():
|
43 |
+
devices.append("cuda")
|
44 |
+
best_device = "cuda"
|
45 |
+
|
46 |
+
return devices, best_device
|
47 |
+
|
48 |
+
|
49 |
+
# --- Helper Functions ---
|
50 |
+
# (Keep get_wav_duration, load_existing_audio, get_safe_filename,
|
51 |
+
# generate_sine_wave, save_audio, parse_text_prompts, get_available_prompts,
|
52 |
+
# create_audio_prompt as they are, they are mostly backend logic)
|
53 |
+
def get_wav_duration(filepath):
|
54 |
+
"""Gets the duration of a WAV file in seconds."""
|
55 |
+
try:
|
56 |
+
with contextlib.closing(wave.open(filepath, "r")) as f:
|
57 |
+
frames = f.getnframes()
|
58 |
+
rate = f.getframerate()
|
59 |
+
if rate > 0:
|
60 |
+
duration = frames / float(rate)
|
61 |
+
return duration
|
62 |
+
else:
|
63 |
+
logger.info(f"Warning: Framerate is 0 for {filepath}")
|
64 |
+
return 0
|
65 |
+
except wave.Error as e:
|
66 |
+
logger.info(f"Warning: Could not read wave file header for {filepath}: {e}")
|
67 |
+
return 0
|
68 |
+
except Exception as e:
|
69 |
+
logger.info(f"Warning: Could not get duration for {filepath}: {e}")
|
70 |
+
return 0
|
71 |
+
|
72 |
+
|
73 |
+
def load_existing_audio() -> List[Dict[str, Any]]:
|
74 |
+
"""Scans the audio directory and loads metadata for existing WAV files."""
|
75 |
+
logger.info("\n--- Loading Existing Audio Files ---")
|
76 |
+
existing_files_metadata = []
|
77 |
+
if not os.path.isdir(GENERATED_AUDIO_DIR):
|
78 |
+
logger.info(f"Directory not found: {GENERATED_AUDIO_DIR}")
|
79 |
+
return []
|
80 |
+
|
81 |
+
try:
|
82 |
+
for filename in os.listdir(GENERATED_AUDIO_DIR):
|
83 |
+
if filename.lower().endswith(".wav"):
|
84 |
+
filepath = os.path.join(GENERATED_AUDIO_DIR, filename)
|
85 |
+
if not os.path.isfile(filepath):
|
86 |
+
continue
|
87 |
+
|
88 |
+
match = re.match(r"^(.*)_(\d{13})\.wav$", filename)
|
89 |
+
text_guess = "Unknown (from filename)"
|
90 |
+
timestamp_ms = 0
|
91 |
+
if match:
|
92 |
+
text_guess = match.group(1).replace("_", " ")
|
93 |
+
try:
|
94 |
+
timestamp_ms = int(match.group(2))
|
95 |
+
except ValueError:
|
96 |
+
timestamp_ms = 0
|
97 |
+
else:
|
98 |
+
text_guess = os.path.splitext(filename)[0].replace("_", " ")
|
99 |
+
|
100 |
+
timestamp_sec = (
|
101 |
+
timestamp_ms / 1000.0
|
102 |
+
if timestamp_ms > 0
|
103 |
+
else os.path.getmtime(filepath)
|
104 |
+
)
|
105 |
+
duration = get_wav_duration(filepath)
|
106 |
+
|
107 |
+
metadata = {
|
108 |
+
"text": text_guess,
|
109 |
+
"path": filepath,
|
110 |
+
"duration": duration,
|
111 |
+
"timestamp": timestamp_sec,
|
112 |
+
}
|
113 |
+
existing_files_metadata.append(metadata)
|
114 |
+
|
115 |
+
except Exception as e:
|
116 |
+
logger.error(f"Error loading existing audio files: {e}")
|
117 |
+
|
118 |
+
existing_files_metadata.sort(key=lambda x: x.get("timestamp", 0))
|
119 |
+
logger.info(
|
120 |
+
f"--- Finished Loading {len(existing_files_metadata)} Existing Files ---"
|
121 |
+
)
|
122 |
+
return existing_files_metadata
|
123 |
+
|
124 |
+
|
125 |
+
def get_safe_filename(base_name: str, extension: str, directory: str) -> str:
|
126 |
+
"""Creates a safe and unique filename in the target directory."""
|
127 |
+
safe_base = "".join(
|
128 |
+
c if c.isalnum() or c in ["_", "-"] else "_" for c in base_name[:50]
|
129 |
+
)
|
130 |
+
timestamp = int(time.time() * 1000)
|
131 |
+
filename = f"{safe_base}_{timestamp}.{extension}"
|
132 |
+
filepath = os.path.join(directory, filename)
|
133 |
+
counter = 1
|
134 |
+
while os.path.exists(filepath):
|
135 |
+
filename = f"{safe_base}_{timestamp}_{counter}.{extension}"
|
136 |
+
filepath = os.path.join(directory, filename)
|
137 |
+
counter += 1
|
138 |
+
return filepath
|
139 |
+
|
140 |
+
|
141 |
+
def update_audio_list(
|
142 |
+
newly_generated_metadata: List[Dict[str, Any]],
|
143 |
+
current_audio_list: List[Dict[str, Any]],
|
144 |
+
) -> List[Dict[str, Any]]:
|
145 |
+
"""Appends new metadata to the list and sorts it by timestamp."""
|
146 |
+
logger.info(f"\n--- Updating Audio List State ---")
|
147 |
+
if not isinstance(current_audio_list, list):
|
148 |
+
logger.info("Current audio list was not a list, initializing.")
|
149 |
+
current_audio_list = []
|
150 |
+
if not isinstance(newly_generated_metadata, list):
|
151 |
+
logger.info("Newly generated metadata is not a list, skipping update.")
|
152 |
+
return current_audio_list
|
153 |
+
|
154 |
+
logger.info(f"Current list size: {len(current_audio_list)}")
|
155 |
+
logger.info(f"Adding {len(newly_generated_metadata)} new items.")
|
156 |
+
updated_list = current_audio_list + newly_generated_metadata
|
157 |
+
updated_list.sort(key=lambda x: x.get("timestamp", 0))
|
158 |
+
logger.info(f"Updated list state size: {len(updated_list)}")
|
159 |
+
logger.info("--- Finished Updating Audio List State ---")
|
160 |
+
return updated_list
|
161 |
+
|
162 |
+
|
163 |
+
def format_audio_list_for_dataframe(audio_list: List[Dict[str, Any]]) -> pd.DataFrame:
|
164 |
+
"""Converts the list of audio metadata dicts into a pandas DataFrame for display."""
|
165 |
+
logger.info("\n--- Formatting List for DataFrame ---")
|
166 |
+
if not audio_list:
|
167 |
+
logger.info("Audio list is empty, returning empty DataFrame.")
|
168 |
+
# Return empty DataFrame with correct columns
|
169 |
+
return pd.DataFrame(columns=["File", "Prompt", "Duration (s)"])
|
170 |
+
|
171 |
+
display_data = []
|
172 |
+
for item in audio_list:
|
173 |
+
filepath = item.get("path", "N/A")
|
174 |
+
filename = os.path.basename(filepath) if filepath != "N/A" else "N/A"
|
175 |
+
# Truncate long text prompts for display in the table
|
176 |
+
text_prompt = item.get("text", "N/A")
|
177 |
+
display_text = (
|
178 |
+
(text_prompt[:75] + "...") if len(text_prompt) > 75 else text_prompt
|
179 |
+
)
|
180 |
+
duration = item.get("duration", 0)
|
181 |
+
display_data.append(
|
182 |
+
{
|
183 |
+
"File": filename,
|
184 |
+
"Prompt": display_text,
|
185 |
+
"Duration (s)": f"{duration:.2f}" if duration else "N/A",
|
186 |
+
# Store the full path implicitly by list order, not shown in df
|
187 |
+
}
|
188 |
+
)
|
189 |
+
|
190 |
+
df = pd.DataFrame(display_data)
|
191 |
+
logger.info(f"Created DataFrame with {len(df)} rows.")
|
192 |
+
logger.info("--- Finished Formatting List for DataFrame ---")
|
193 |
+
return df
|
194 |
+
|
195 |
+
|
196 |
+
def handle_row_selection(
|
197 |
+
audio_list: List[Dict[str, Any]], evt: gr.SelectData
|
198 |
+
) -> Tuple[Optional[str], int]:
|
199 |
+
"""
|
200 |
+
Handles the selection event from the DataFrame.
|
201 |
+
Updates the audio player with the selected file's path.
|
202 |
+
Returns the filepath and the selected index.
|
203 |
+
"""
|
204 |
+
logger.info("\n--- Handling Row Selection ---")
|
205 |
+
selected_index = evt.index[0] if evt.index else None # Get row index
|
206 |
+
logger.info(f"DataFrame row selected. Event data: {evt}")
|
207 |
+
|
208 |
+
if selected_index is not None and 0 <= selected_index < len(audio_list):
|
209 |
+
selected_item = audio_list[selected_index]
|
210 |
+
filepath = selected_item.get("path")
|
211 |
+
logger.info(f"Selected item at index {selected_index}: {selected_item}")
|
212 |
+
if filepath and os.path.exists(filepath):
|
213 |
+
logger.info(f"Updating audio player with: {filepath}")
|
214 |
+
logger.info("--- Finished Handling Row Selection (Success) ---")
|
215 |
+
return filepath, selected_index
|
216 |
+
else:
|
217 |
+
logger.info(f"File not found for selected item: {filepath}")
|
218 |
+
gr.Warning(
|
219 |
+
f"File not found for selected row: {os.path.basename(filepath or 'N/A')}"
|
220 |
+
)
|
221 |
+
logger.info("--- Finished Handling Row Selection (File Not Found) ---")
|
222 |
+
return None, selected_index # Keep index, but clear player
|
223 |
+
else:
|
224 |
+
logger.info("Invalid selection index or empty list.")
|
225 |
+
logger.info("--- Finished Handling Row Selection (Invalid Index) ---")
|
226 |
+
return None, -1 # Clear player and indicate no valid selection
|
227 |
+
|
228 |
+
|
229 |
+
def handle_delete_selected(
|
230 |
+
selected_index: int, current_audio_list: List[Dict[str, Any]]
|
231 |
+
) -> Tuple[List[Dict[str, Any]], int, Optional[str]]:
|
232 |
+
"""
|
233 |
+
Deletes the audio file corresponding to the selected index.
|
234 |
+
Updates the main audio list state.
|
235 |
+
Clears the selection index and audio player.
|
236 |
+
"""
|
237 |
+
logger.info("\n--- Handling Delete Selected ---")
|
238 |
+
logger.info(f"Attempting deletion for selected index: {selected_index}")
|
239 |
+
|
240 |
+
if (
|
241 |
+
selected_index is None
|
242 |
+
or selected_index < 0
|
243 |
+
or selected_index >= len(current_audio_list)
|
244 |
+
):
|
245 |
+
gr.Warning("No valid audio selected for deletion.")
|
246 |
+
logger.info("No valid index provided.")
|
247 |
+
# Return current list, clear index, clear player
|
248 |
+
return current_audio_list, -1, None
|
249 |
+
|
250 |
+
item_to_delete = current_audio_list[selected_index]
|
251 |
+
filepath_to_delete = item_to_delete.get("path")
|
252 |
+
logger.info(f"Item to delete: {item_to_delete}")
|
253 |
+
|
254 |
+
# Create the new list excluding the item
|
255 |
+
# Corrected slicing logic: include elements before and after the index
|
256 |
+
new_audio_list = (
|
257 |
+
current_audio_list[:selected_index] + current_audio_list[selected_index + 1 :]
|
258 |
+
)
|
259 |
+
logger.info(f"New list size after filtering: {len(new_audio_list)}")
|
260 |
+
|
261 |
+
# Try to delete the file from disk
|
262 |
+
deletion_successful_on_disk = False
|
263 |
+
try:
|
264 |
+
if filepath_to_delete and os.path.exists(filepath_to_delete):
|
265 |
+
os.remove(filepath_to_delete)
|
266 |
+
logger.info(f"Successfully deleted file: {filepath_to_delete}")
|
267 |
+
gr.Info(f"Deleted {os.path.basename(filepath_to_delete)}")
|
268 |
+
deletion_successful_on_disk = True
|
269 |
+
elif filepath_to_delete:
|
270 |
+
logger.info(f"File not found for deletion: {filepath_to_delete}")
|
271 |
+
gr.Warning("Audio entry removed from list, but file was not found on disk.")
|
272 |
+
deletion_successful_on_disk = True # Consider list update successful
|
273 |
+
else:
|
274 |
+
logger.info("Invalid filepath in selected item.")
|
275 |
+
gr.Warning("Could not delete: Invalid file path associated with selection.")
|
276 |
+
# Revert list change if filepath was invalid from the start? Or keep it removed?
|
277 |
+
# Let's keep it removed from the list for consistency.
|
278 |
+
deletion_successful_on_disk = True # Treat as success for list update
|
279 |
+
|
280 |
+
except OSError as e:
|
281 |
+
logger.info(f"Error deleting file {filepath_to_delete}: {e}")
|
282 |
+
logger.info(traceback.format_exc())
|
283 |
+
gr.Error(f"Error deleting file: {e}")
|
284 |
+
# If file deletion fails, we still return the updated list (item removed).
|
285 |
+
# If you want to revert the list change on OS error, return `current_audio_list` here.
|
286 |
+
|
287 |
+
logger.info("--- Finished Deleting Selected Item ---")
|
288 |
+
# Return the updated list, clear the selected index, clear the audio player
|
289 |
+
return new_audio_list, -1, None
|
290 |
+
|
291 |
+
|
292 |
+
def get_available_prompts() -> List[str]:
|
293 |
+
"""Loads available prompt file names."""
|
294 |
+
try:
|
295 |
+
prompts = [
|
296 |
+
f
|
297 |
+
for f in os.listdir(PROMPT_DIR)
|
298 |
+
if os.path.isfile(os.path.join(PROMPT_DIR, f))
|
299 |
+
and f.lower().endswith((".npz", ".npy", ".json"))
|
300 |
+
]
|
301 |
+
|
302 |
+
if len(prompts) == 0:
|
303 |
+
gr.Info("No prompts found.", duration=3)
|
304 |
+
|
305 |
+
return ["None"] + prompts
|
306 |
+
except Exception as e:
|
307 |
+
logger.info(f"Error loading prompts: {e}")
|
308 |
+
gr.Info(f"Error loading prompts {e}", duration=3, title="Error")
|
309 |
+
return ["None"]
|
310 |
+
|
311 |
+
|
312 |
+
def update_available_prompts() -> gr.update:
|
313 |
+
try:
|
314 |
+
prompts = [
|
315 |
+
f
|
316 |
+
for f in os.listdir(PROMPT_DIR)
|
317 |
+
if os.path.isfile(os.path.join(PROMPT_DIR, f))
|
318 |
+
and f.lower().endswith((".npz", ".npy", ".json"))
|
319 |
+
]
|
320 |
+
|
321 |
+
if len(prompts) == 0:
|
322 |
+
gr.Info("No prompts found.", duration=3)
|
323 |
+
|
324 |
+
return gr.update(choices=["None"] + prompts)
|
325 |
+
except Exception as e:
|
326 |
+
logger.info(f"Error loading prompts: {e}")
|
327 |
+
gr.Info(f"Error loading prompts {e}", duration=3, title="Error")
|
328 |
+
return gr.update()
|
329 |
+
|
330 |
+
|
331 |
+
def generate_batch_audio(
|
332 |
+
text: str,
|
333 |
+
semantic_temp: float,
|
334 |
+
coarse_temp: float,
|
335 |
+
fine_temp: float,
|
336 |
+
manual_seed: int,
|
337 |
+
model_type: str,
|
338 |
+
inference_device: str,
|
339 |
+
selected_prompt_name: Optional[str],
|
340 |
+
) -> List[Dict[str, Any]]:
|
341 |
+
"""
|
342 |
+
Generates audio (sine wave) for each line of text input.
|
343 |
+
Returns metadata for generated files.
|
344 |
+
"""
|
345 |
+
gc.collect()
|
346 |
+
|
347 |
+
torch.manual_seed(manual_seed)
|
348 |
+
if not text:
|
349 |
+
gr.Warning("No valid text prompts provided.")
|
350 |
+
return []
|
351 |
+
|
352 |
+
generated_metadata = []
|
353 |
+
|
354 |
+
bark_prompt = None
|
355 |
+
if selected_prompt_name != "None":
|
356 |
+
gr.Info("Loading audio prompt...")
|
357 |
+
prompt_path = os.path.join(PROMPT_DIR, selected_prompt_name)
|
358 |
+
bark_prompt = BarkPrompt.load_prompt(
|
359 |
+
prompt_path, torch.device(inference_device)
|
360 |
+
)
|
361 |
+
|
362 |
+
generation_config = BarkGenerationConfig(
|
363 |
+
temperature=semantic_temp,
|
364 |
+
generate_coarse_temperature=coarse_temp,
|
365 |
+
generate_fine_temperature=fine_temp,
|
366 |
+
use_small_model=True if model_type == "small" else False,
|
367 |
+
)
|
368 |
+
|
369 |
+
# split the text into sentences
|
370 |
+
sentences = sent_tokenize(text)
|
371 |
+
|
372 |
+
gr.Info("Generating Audio....", duration=120)
|
373 |
+
waves = generate_audio(
|
374 |
+
texts=sentences,
|
375 |
+
prompt=bark_prompt,
|
376 |
+
generation_config=generation_config,
|
377 |
+
silent=True,
|
378 |
+
)
|
379 |
+
audio = np.concatenate(waves, axis=-1)
|
380 |
+
|
381 |
+
output_filepath = get_safe_filename(text, "wav", GENERATED_AUDIO_DIR)
|
382 |
+
save_audio_file(audio, DEFAULT_AUDIO_SAMPLE_RATE, output_filepath)
|
383 |
+
duration_sec = audio.shape[0] / DEFAULT_AUDIO_SAMPLE_RATE
|
384 |
+
metadata = {
|
385 |
+
"text": text,
|
386 |
+
"path": output_filepath,
|
387 |
+
"duration": duration_sec,
|
388 |
+
"timestamp": time.time(),
|
389 |
+
}
|
390 |
+
generated_metadata.append(metadata)
|
391 |
+
gr.Info("Done!", duration=5)
|
392 |
+
return generated_metadata
|
393 |
+
|
394 |
+
|
395 |
+
def create_audio_prompt(
|
396 |
+
uploaded_audio_file: Optional[str],
|
397 |
+
device: str,
|
398 |
+
progress: gr.Progress = gr.Progress(),
|
399 |
+
) -> gr.update:
|
400 |
+
"""Processes an uploaded audio file to create a voice prompt file (stub)."""
|
401 |
+
logger.info("\n--- Starting Prompt Creation ---")
|
402 |
+
if uploaded_audio_file is None or len(uploaded_audio_file) == 0:
|
403 |
+
gr.Warning("No audio file uploaded!")
|
404 |
+
return gr.update()
|
405 |
+
|
406 |
+
logger.info(f"Processing uploaded file: {uploaded_audio_file}")
|
407 |
+
|
408 |
+
try:
|
409 |
+
progress(0, desc="Starting prompt creation...")
|
410 |
+
new_prompt_filename = None
|
411 |
+
progress(0.2, desc="Extracting prompt features...")
|
412 |
+
audio_file = AudioFile(audio_file_path=uploaded_audio_file, max_duration=10)
|
413 |
+
prompt = create_bark_prompt(
|
414 |
+
audio_file=audio_file, temperature=1, eos_p=0.2, device=torch.device(device)
|
415 |
+
)
|
416 |
+
|
417 |
+
progress(0.8, desc="Saving prompt file...")
|
418 |
+
original_basename = os.path.splitext(os.path.basename(uploaded_audio_file))[0]
|
419 |
+
prompt_filepath = get_safe_filename(original_basename, "json", PROMPT_DIR)
|
420 |
+
new_prompt_filename = os.path.basename(prompt_filepath)
|
421 |
+
|
422 |
+
ok = prompt.save_prompt(prompt_filepath)
|
423 |
+
if ok:
|
424 |
+
progress(1.0, desc="Prompt creation complete.")
|
425 |
+
|
426 |
+
else:
|
427 |
+
progress(1.0, desc="Error when saving prompt")
|
428 |
+
|
429 |
+
new_choices = get_available_prompts()
|
430 |
+
|
431 |
+
return gr.update(choices=new_choices, value=new_prompt_filename)
|
432 |
+
|
433 |
+
except Exception as e:
|
434 |
+
logger.info(f"Error creating prompt: {e}")
|
435 |
+
gr.Error(f"Prompt creation failed: {e}")
|
436 |
+
return f"Error creating prompt: {e}", gr.update()
|
generate_audio_semantic_dataset.py
ADDED
@@ -0,0 +1,155 @@
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
from core.bark.generate_audio_semantic_dataset import (
|
7 |
+
generate_wav_semantic_dataset,
|
8 |
+
BarkGenerationConfig,
|
9 |
+
)
|
10 |
+
from core.utils import upload_file_to_hf, zip_folder
|
11 |
+
|
12 |
+
|
13 |
+
logging.basicConfig(
|
14 |
+
level=logging.INFO,
|
15 |
+
format="%(asctime)s - %(levelname)s - %(message)s",
|
16 |
+
)
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
def parse_dataset_args(args_list=None):
|
21 |
+
"""Parse arguments specific to dataset creation."""
|
22 |
+
parser = argparse.ArgumentParser(description="Audio Semantic Dataset Creation")
|
23 |
+
|
24 |
+
parser.add_argument(
|
25 |
+
"--text-file",
|
26 |
+
type=str,
|
27 |
+
default="data/test_data.txt",
|
28 |
+
help="Path to text file for dataset generation",
|
29 |
+
)
|
30 |
+
parser.add_argument(
|
31 |
+
"--batch-size",
|
32 |
+
type=int,
|
33 |
+
default=2,
|
34 |
+
help="Batch size for processing (default: 1)",
|
35 |
+
)
|
36 |
+
|
37 |
+
parser.add_argument(
|
38 |
+
"--output-dir",
|
39 |
+
type=str,
|
40 |
+
default="./dataset",
|
41 |
+
help="Output directory for generated files (default: ./dataset)",
|
42 |
+
)
|
43 |
+
parser.add_argument(
|
44 |
+
"--max-tokens",
|
45 |
+
type=int,
|
46 |
+
default=256,
|
47 |
+
help="Maximum tokens per example (default: 256)",
|
48 |
+
)
|
49 |
+
parser.add_argument(
|
50 |
+
"--use-small-model",
|
51 |
+
action="store_true",
|
52 |
+
help="Use small model for generation",
|
53 |
+
)
|
54 |
+
parser.add_argument(
|
55 |
+
"--save-raw-audio",
|
56 |
+
action="store_true",
|
57 |
+
help="Store generated audio as .wav instead of .npz",
|
58 |
+
)
|
59 |
+
parser.add_argument(
|
60 |
+
"--publish-hf",
|
61 |
+
action="store_true",
|
62 |
+
help="Publish dataset to HuggingFace Hub",
|
63 |
+
)
|
64 |
+
parser.add_argument(
|
65 |
+
"--repo-id",
|
66 |
+
type=str,
|
67 |
+
help="HuggingFace repo ID to publish to",
|
68 |
+
)
|
69 |
+
parser.add_argument(
|
70 |
+
"--path-in-repo",
|
71 |
+
type=str,
|
72 |
+
help="Path in HF repo",
|
73 |
+
default=None,
|
74 |
+
)
|
75 |
+
parser.add_argument(
|
76 |
+
"--silent", action="store_true", help="Suppress progress output"
|
77 |
+
)
|
78 |
+
|
79 |
+
return parser.parse_args(args_list)
|
80 |
+
|
81 |
+
|
82 |
+
def create_audio_semantic_dataset(
|
83 |
+
text_file: str,
|
84 |
+
output_dir: str = "./dataset",
|
85 |
+
batch_size: int = 1,
|
86 |
+
max_tokens: int = 256,
|
87 |
+
use_small_model: bool = False,
|
88 |
+
save_raw_audio: bool = False,
|
89 |
+
publish_hf: bool = False,
|
90 |
+
repo_id: Optional[str] = None,
|
91 |
+
path_in_repo: Optional[str] = None,
|
92 |
+
silent: bool = False,
|
93 |
+
) -> None:
|
94 |
+
"""Create audio semantic dataset from text file.
|
95 |
+
|
96 |
+
Can be called directly with parameters or via command line using parse_dataset_args().
|
97 |
+
|
98 |
+
Args:
|
99 |
+
text_file: Path to input text file
|
100 |
+
output_dir: Directory to save generated dataset
|
101 |
+
batch_size: Batch size for processing
|
102 |
+
max_tokens: Maximum tokens per example
|
103 |
+
use_small_model: Whether to use small model
|
104 |
+
save_raw_audio: Save as raw audio (.wav) instead of .npz
|
105 |
+
publish_hf: Whether to publish to HuggingFace Hub
|
106 |
+
repo_id: HF repo ID to publish to
|
107 |
+
path_in_repo: Path in HF repo
|
108 |
+
silent: Suppress progress output
|
109 |
+
"""
|
110 |
+
os.makedirs(output_dir, exist_ok=True)
|
111 |
+
|
112 |
+
if not os.path.isfile(text_file):
|
113 |
+
raise FileNotFoundError(f"Text file not found: {text_file}")
|
114 |
+
|
115 |
+
logger.info(f"Starting dataset generation from {text_file}")
|
116 |
+
generation_config = BarkGenerationConfig(
|
117 |
+
temperature=None,
|
118 |
+
generate_coarse_temperature=None,
|
119 |
+
generate_fine_temperature=None,
|
120 |
+
use_small_model=use_small_model,
|
121 |
+
)
|
122 |
+
|
123 |
+
generate_wav_semantic_dataset(
|
124 |
+
text_file_path=text_file,
|
125 |
+
generation_config=generation_config,
|
126 |
+
batch_size=batch_size,
|
127 |
+
save_path=output_dir,
|
128 |
+
save_data_as_raw_audio=save_raw_audio,
|
129 |
+
silent=silent,
|
130 |
+
)
|
131 |
+
logger.info("Dataset generation completed")
|
132 |
+
|
133 |
+
if publish_hf and repo_id:
|
134 |
+
logger.info("Publishing dataset to huggingface hub")
|
135 |
+
zip_path = "./dataset.zip"
|
136 |
+
success = zip_folder(output_dir, zip_path)
|
137 |
+
if not success:
|
138 |
+
raise RuntimeError(f"Unable to zip folder {output_dir}")
|
139 |
+
upload_file_to_hf(zip_path, repo_id, "dataset", path_in_repo=path_in_repo)
|
140 |
+
|
141 |
+
|
142 |
+
if __name__ == "__main__":
|
143 |
+
args = parse_dataset_args()
|
144 |
+
create_audio_semantic_dataset(
|
145 |
+
text_file=args.text_file,
|
146 |
+
output_dir=args.output_dir,
|
147 |
+
batch_size=args.batch_size,
|
148 |
+
max_tokens=args.max_tokens,
|
149 |
+
use_small_model=args.use_small_model,
|
150 |
+
save_raw_audio=args.save_raw_audio,
|
151 |
+
publish_hf=args.publish_hf,
|
152 |
+
repo_id=args.repo_id,
|
153 |
+
path_in_repo=args.path_in_repo,
|
154 |
+
silent=args.silent,
|
155 |
+
)
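A typical programmatic invocation of the function above (the same options map one-to-one onto the CLI flags; paths and the repo ID are placeholders):

create_audio_semantic_dataset(
    text_file="data/test_data.txt",
    output_dir="./dataset",
    batch_size=2,
    use_small_model=True,
    save_raw_audio=False,
    publish_hf=False,        # set True and pass repo_id="username/my-dataset" to upload
)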
|
prompts/de_speaker_0.npz
ADDED
Binary file (39.6 kB). View file
|
|
prompts/de_speaker_1.npz
ADDED
Binary file (27.5 kB). View file
|
|
prompts/de_speaker_2.npz
ADDED
Binary file (24.7 kB). View file
|
|
prompts/de_speaker_3.npz
ADDED
Binary file (31.3 kB). View file
|
|
prompts/de_speaker_4.npz
ADDED
Binary file (30.7 kB). View file
|
|
prompts/de_speaker_5.npz
ADDED
Binary file (31.3 kB). View file
|
|
prompts/de_speaker_6.npz
ADDED
Binary file (23.2 kB). View file
|
|
prompts/de_speaker_7.npz
ADDED
Binary file (40.1 kB). View file
|
|
prompts/de_speaker_8.npz
ADDED
Binary file (28.5 kB). View file
|
|
prompts/de_speaker_9.npz
ADDED
Binary file (51.1 kB). View file
|
|
prompts/en_speaker_0.npz
ADDED
Binary file (28.1 kB). View file
|
|
prompts/en_speaker_1.npz
ADDED
Binary file (25.2 kB). View file
|
|
prompts/en_speaker_2.npz
ADDED
Binary file (26.2 kB). View file
|
|
prompts/en_speaker_3.npz
ADDED
Binary file (35 kB). View file
|
|
prompts/en_speaker_4.npz
ADDED
Binary file (23.8 kB). View file
|
|
prompts/en_speaker_5.npz
ADDED
Binary file (24.7 kB). View file
|
|