Spaces: Running on Zero
Commit · a0e2cb7
Parent(s): init
This view is limited to 50 files because it contains too many changes.
- .gitattributes +38 -0
- .gitignore +215 -0
- README.md +83 -0
- SongBloom/g2p/__init__.py +0 -0
- SongBloom/g2p/cn_zh_g2p/__init__.py +106 -0
- SongBloom/g2p/cn_zh_g2p/chinese.py +173 -0
- SongBloom/g2p/cn_zh_g2p/cmudict-fast.rep +0 -0
- SongBloom/g2p/cn_zh_g2p/cmudict.rep +0 -0
- SongBloom/g2p/cn_zh_g2p/engdict-hot.rep +2 -0
- SongBloom/g2p/cn_zh_g2p/engdict_cache.pickle +3 -0
- SongBloom/g2p/cn_zh_g2p/english.py +369 -0
- SongBloom/g2p/cn_zh_g2p/nltk_data/corpora/cmudict.zip +3 -0
- SongBloom/g2p/cn_zh_g2p/nltk_data/corpora/cmudict/README +76 -0
- SongBloom/g2p/cn_zh_g2p/nltk_data/corpora/cmudict/cmudict +0 -0
- SongBloom/g2p/cn_zh_g2p/nltk_data/taggers/averaged_perceptron_tagger.zip +3 -0
- SongBloom/g2p/cn_zh_g2p/nltk_data/taggers/averaged_perceptron_tagger/averaged_perceptron_tagger.pickle +3 -0
- SongBloom/g2p/cn_zh_g2p/opencpop-strict.txt +429 -0
- SongBloom/g2p/cn_zh_g2p/symbols.py +401 -0
- SongBloom/g2p/cn_zh_g2p/tone_sandhi.py +806 -0
- SongBloom/g2p/cn_zh_g2p/zh_normalization/README.md +16 -0
- SongBloom/g2p/cn_zh_g2p/zh_normalization/__init__.py +14 -0
- SongBloom/g2p/cn_zh_g2p/zh_normalization/char_convert.py +46 -0
- SongBloom/g2p/cn_zh_g2p/zh_normalization/chronology.py +134 -0
- SongBloom/g2p/cn_zh_g2p/zh_normalization/constants.py +62 -0
- SongBloom/g2p/cn_zh_g2p/zh_normalization/num.py +282 -0
- SongBloom/g2p/cn_zh_g2p/zh_normalization/phonecode.py +63 -0
- SongBloom/g2p/cn_zh_g2p/zh_normalization/quantifier.py +63 -0
- SongBloom/g2p/cn_zh_g2p/zh_normalization/text_normlization.py +165 -0
- SongBloom/g2p/lyric_common.py +81 -0
- SongBloom/g2p/pinyin/__init__.py +430 -0
- SongBloom/g2p/pinyin/pinyin.py +137 -0
- SongBloom/g2p/pinyin/symbols.py +71 -0
- SongBloom/models/base/sample.py +57 -0
- SongBloom/models/base/utils.py +57 -0
- SongBloom/models/musicgen/__init__.py +0 -0
- SongBloom/models/musicgen/conditioners/__init__.py +37 -0
- SongBloom/models/musicgen/conditioners/base.py +872 -0
- SongBloom/models/musicgen/conditioners/text.py +254 -0
- SongBloom/models/musicgen/conditioners/wav.py +74 -0
- SongBloom/models/musicgen/get_backend.py +76 -0
- SongBloom/models/musicgen/modules/streaming.py +125 -0
- SongBloom/models/musicldm/__init__.py +0 -0
- SongBloom/models/musicldm/inference/__init__.py +0 -0
- SongBloom/models/musicldm/inference/sampling.py +271 -0
- SongBloom/models/musicldm/musicldm_dit.py +24 -0
- SongBloom/models/songbloom/songbloom_mvsa.py +572 -0
- SongBloom/models/songbloom/songbloom_pl.py +224 -0
- SongBloom/models/transformer.py +937 -0
- SongBloom/models/vae_frontend/__init__.py +96 -0
- SongBloom/models/vae_frontend/autoencoders.py +657 -0
.gitattributes
ADDED
@@ -0,0 +1,38 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.flac filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,215 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[codz]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py.cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+#poetry.toml
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+#pdm.lock
+#pdm.toml
+.pdm-python
+.pdm-build/
+
+# pixi
+# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+#pixi.lock
+# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+# in the .venv directory. It is recommended not to include this directory in version control.
+.pixi
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.envrc
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# Abstra
+# Abstra is an AI-powered process automation framework.
+# Ignore directories containing user credentials, local state, and settings.
+# Learn more at https://abstra.io/docs
+.abstra/
+
+# Visual Studio Code
+# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+# and can be added to the global gitignore or merged into this file. However, if you prefer,
+# you could uncomment the following to ignore the entire vscode folder
+# .vscode/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
+
+# Cursor
+# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
+# refer to https://docs.cursor.com/context/ignore-files
+.cursorignore
+.cursorindexingignore
+
+# Marimo
+marimo/_static/
+marimo/_lsp/
+__marimo__/
+
+# Streamlit
+.streamlit/secrets.toml
+
+__pycache__
+output
+cache
+third_party
README.md
ADDED
@@ -0,0 +1,83 @@
+# [SongBloom]: *Coherent Song Generation via Interleaved Autoregressive Sketching and Diffusion Refinement*
+
+We propose **SongBloom**, a novel framework for full-length song generation that leverages an interleaved paradigm of autoregressive sketching and diffusion-based refinement. SongBloom employs an autoregressive diffusion model that combines the high fidelity of diffusion models with the scalability of language models.
+Specifically, it gradually extends a musical sketch from short to long and refines the details from coarse to fine-grained. The interleaved generation paradigm effectively integrates prior semantic and acoustic context to guide the generation process.
+Experimental results demonstrate that SongBloom outperforms existing methods across both subjective and objective metrics and achieves performance comparable to state-of-the-art commercial music generation platforms.
+
+![SongBloom](docs/SongBloom.png)
+
+Demo page: [https://cypress-yang.github.io/SongBloom_demo](https://cypress-yang.github.io/SongBloom_demo)
+
+ArXiv: [https://arxiv.org/abs/2506.07634](https://arxiv.org/abs/2506.07634)
+
+## Prepare Environment
+
+```bash
+conda create -n SongBloom python==3.8.12
+conda activate SongBloom
+
+# yum install libsndfile
+# pip install torch==2.2.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118  # for a different CUDA version
+pip install -r requirements.txt
+```
+
+## Data Preparation
+
+Prepare a .jsonl file, where each line is a JSON object:
+
+```json
+{
+    "idx": "The index of each sample",
+    "lyrics": "The lyrics to be generated",
+    "prompt_wav": "The path of the style prompt audio"
+}
+```
+
+An example can be found at [example/test.jsonl](example/test.jsonl).
+
+The prompt wav should be a 10-second, 48kHz audio clip.
+
+Details about the lyric format can be found in [docs/lyric_format.md](docs/lyric_format.md).
+
+## Inference
+
+```bash
+source set_env.sh
+
+python3 infer.py --input-jsonl example/test.jsonl
+
+# For GPUs with low VRAM, such as the RTX 4090, set the dtype to bfloat16:
+python3 infer.py --input-jsonl example/test.jsonl --dtype bfloat16
+
+# SongBloom also supports flash-attn (optional). To enable it, install flash-attn (v2.6.3 was used during training) manually and set os.environ['DISABLE_FLASH_ATTN'] = "0" in infer.py:8
+```
+
+## Models
+
+| Name                 | Size | Max Length | Prompt type                | 🤗                                                    |
+| -------------------- | ---- | ---------- | -------------------------- | ----------------------------------------------------- |
+| songbloom_full_150s  | 2B   | 2m30s      | 10s wav                    | [link](https://huggingface.co/CypressYang/SongBloom)  |
+| songbloom_mulan_150s | 2B   | 2m30s      | 10s wav / text description | coming soon                                            |
+| ...                  |      |            |                            |                                                        |
+
+## TODO List
+
+- [ ] Support Text Description
+- [ ] Full version
+
+## Citation
+
+```
+@article{yang2025songbloom,
+  title={SongBloom: Coherent Song Generation via Interleaved Autoregressive Sketching and Diffusion Refinement},
+  author={Yang, Chenyu and Wang, Shuai and Chen, Hangting and Tan, Wei and Yu, Jianwei and Li, Haizhou},
+  journal={arXiv preprint arXiv:2506.07634},
+  year={2025}
+}
+```
+
+## License
+
+SongBloom (code and weights) is released under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
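
The Data Preparation section above specifies the input as one JSON object per line. As a minimal sketch (not part of the repo; the sample index, lyrics placeholder, and file paths below are illustrative, with the real lyric syntax defined in docs/lyric_format.md), an input file matching that schema can be produced like this:

```python
# Minimal sketch: write a SongBloom input .jsonl matching the schema in the
# README above. The lyrics text and paths are illustrative placeholders.
import json

samples = [
    {
        "idx": "demo_000",                    # unique index of the sample
        "lyrics": "[verse] la la la ...",     # placeholder; see docs/lyric_format.md
        "prompt_wav": "example/prompt.wav",   # 10-second, 48 kHz style prompt
    },
]

with open("example/test.jsonl", "w", encoding="utf-8") as f:
    for sample in samples:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
```

Each line then deserializes back to one generation request for `infer.py --input-jsonl`.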
SongBloom/g2p/__init__.py
ADDED
File without changes
SongBloom/g2p/cn_zh_g2p/__init__.py
ADDED
@@ -0,0 +1,106 @@
+from . import chinese, english  # , japanese  (temporarily disabled)
+from .symbols import *
+import yaml
+language_module_map = {"zh": chinese, "en": english}  # , "ja": japanese
+
+def is_chinese(uchar):
+    if uchar >= u'\u4e00' and uchar <= u'\u9fa5':
+        return True
+    else:
+        return False
+
+import re
+
+# def split_text(text):
+#     chinese_pattern = r'[\u4e00-\u9fa5][\u4e00-\u9fa5\ \,\.\!\?\,\。]+'
+#     english_pattern = r'[a-zA-Z][a-zA-Z\'\ \,\.\!\?]+'
+
+#     chinese_text = re.findall(chinese_pattern, text)
+#     print(chinese_text)
+#     english_text = re.findall(english_pattern, text)
+
+#     return chinese_text, english_text
+
+def split_text(text):
+    pattern = re.compile("|".join(re.escape(p) for p in chinese.rep_map.keys()))
+    text = pattern.sub(lambda x: chinese.rep_map[x.group()], text)
+
+    result = []
+    lang = []
+    buffer = ""
+    chinese_pattern = r'[\u4e00-\u9fa5]'
+    special_pattern = r'[\,\.\!\?\…\-]'
+    # TODO: double-check this
+    for char in text:
+        if re.match(special_pattern, char):
+            if buffer:
+                if not re.match(chinese_pattern, buffer[0]):
+                    result.append(buffer)
+                    lang.append('en')
+                else:
+                    result.append(buffer)
+                    lang.append("zh")
+            result.append(char)
+            lang.append('sp')
+            buffer = ""
+
+        elif re.match(chinese_pattern, char):
+            if buffer and not re.match(chinese_pattern, buffer[-1]):
+                result.append(buffer)
+                buffer = ""
+                lang.append('en')
+            buffer += char
+        else:
+            if buffer and re.match(chinese_pattern, buffer[-1]):
+                result.append(buffer)
+                buffer = ""
+                lang.append("zh")
+            buffer += char
+
+    if buffer:
+        result.append(buffer)
+        lang.append("zh" if re.match(chinese_pattern, buffer[-1]) else 'en')
+
+    return result, lang
+
+def mixed_language_to_phoneme(text):
+    segments, lang = split_text(text)
+    # print(segments, lang)
+    result = [language_to_phoneme(s, l) for s, l in zip(segments, lang)]
+    phones, word2ph = [], []
+    for p, w, n in result:
+        phones += p
+        if w is None:
+            w = []
+        word2ph += w
+    return phones, word2ph
+
+
+def language_to_phoneme(text, language):
+    if language == 'sp':
+        return [text], None, text
+    language_module = language_module_map[language]
+    norm_text = language_module.text_normalize(text)
+    if language == "zh":
+        phones, word2ph = language_module.g2p(norm_text)
+        assert len(phones) == sum(word2ph)
+        assert len(norm_text) == len(word2ph)
+    else:
+        try:
+            phones = language_module.g2p(norm_text)
+        except:
+            phones = [norm_text]
+        word2ph = None
+
+    # for ph in phones:
+    #     assert ph in symbols, ph
+    return phones, word2ph, norm_text
+
+def gen_vocabs():
+    yaml.dump(symbols, open('./vocab.yaml', 'w'))
+
+class G2P_Mix():
+    def __call__(self, text):
+        phones, word2ph = mixed_language_to_phoneme(text)
+        return ' '.join(phones)
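
Taken together, `split_text` segments mixed Chinese/English text (tagging punctuation as `sp`) and `G2P_Mix` flattens the per-segment phonemes into a single string. A minimal usage sketch, assuming the repository root is on `PYTHONPATH` and the g2p dependencies (pypinyin, jieba_fast, cn2an, g2p_en, wordsegment, the bundled nltk_data) are available; the sample sentence is arbitrary:

```python
# Minimal usage sketch for the module above (sample text is arbitrary).
from SongBloom.g2p.cn_zh_g2p import G2P_Mix, split_text

segments, langs = split_text("今天天气不错hello world!")
print(segments)  # text chunks: Chinese runs, English runs, punctuation marks
print(langs)     # parallel tags: "zh", "en", or "sp" for punctuation

g2p = G2P_Mix()
print(g2p("今天天气不错hello world!"))  # space-joined phoneme sequence
```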
SongBloom/g2p/cn_zh_g2p/chinese.py
ADDED
@@ -0,0 +1,173 @@
+import os
+import pdb
+import re
+
+import cn2an
+from pypinyin import lazy_pinyin, Style
+
+from .symbols import punctuation
+from .tone_sandhi import ToneSandhi
+from .zh_normalization.text_normlization import TextNormalizer
+
+normalizer = lambda x: cn2an.transform(x, "an2cn")
+
+current_file_path = os.path.dirname(__file__)
+pinyin_to_symbol_map = {
+    line.split("\t")[0]: line.strip().split("\t")[1]
+    for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
+}
+
+import jieba_fast.posseg as psg
+
+
+rep_map = {
+    ":": ",",
+    ";": ",",
+    ",": ",",
+    "。": ".",
+    "!": "!",
+    "?": "?",
+    "\n": ".",
+    "·": ",",
+    "、": ",",
+    "...": "…",
+    "$": ".",
+    "/": ",",
+    "—": "-",
+    "~": "…",
+    "~": "…",
+}
+
+tone_modifier = ToneSandhi()
+
+
+def replace_punctuation(text):
+    text = text.replace("嗯", "恩").replace("呣", "母")
+    pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
+
+    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
+    replaced_text = re.sub(
+        r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
+    )
+
+    return replaced_text
+
+
+def g2p(text):
+    pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
+    sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
+    phones, word2ph = _g2p(sentences)
+    return phones, word2ph
+
+
+def _get_initials_finals(word):
+    initials = []
+    finals = []
+    orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
+    orig_finals = lazy_pinyin(
+        word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
+    )
+    for c, v in zip(orig_initials, orig_finals):
+        initials.append(c)
+        finals.append(v)
+    return initials, finals
+
+
+def _g2p(segments):
+    phones_list = []
+    word2ph = []
+    for seg in segments:
+        pinyins = []
+        # Replace all English words in the sentence
+        seg = re.sub("[a-zA-Z]+", "", seg)
+        seg_cut = psg.lcut(seg)
+        initials = []
+        finals = []
+        seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
+        for word, pos in seg_cut:
+            if pos == "eng":
+                continue
+            sub_initials, sub_finals = _get_initials_finals(word)
+            sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
+            initials.append(sub_initials)
+            finals.append(sub_finals)
+
+            # assert len(sub_initials) == len(sub_finals) == len(word)
+        initials = sum(initials, [])
+        finals = sum(finals, [])
+        #
+        for c, v in zip(initials, finals):
+            raw_pinyin = c + v
+            # NOTE: post process for pypinyin outputs
+            # we discriminate i, ii and iii
+            if c == v:
+                assert c in punctuation
+                phone = [c]
+                word2ph.append(1)
+            else:
+                v_without_tone = v[:-1]
+                tone = v[-1]
+
+                pinyin = c + v_without_tone
+                assert tone in "12345"
+
+                if c:
+                    # syllable with an initial
+                    v_rep_map = {
+                        "uei": "ui",
+                        "iou": "iu",
+                        "uen": "un",
+                    }
+                    if v_without_tone in v_rep_map.keys():
+                        pinyin = c + v_rep_map[v_without_tone]
+                else:
+                    # bare final (no initial)
+                    pinyin_rep_map = {
+                        "ing": "ying",
+                        "i": "yi",
+                        "in": "yin",
+                        "u": "wu",
+                    }
+                    if pinyin in pinyin_rep_map.keys():
+                        pinyin = pinyin_rep_map[pinyin]
+                    else:
+                        single_rep_map = {
+                            "v": "yu",
+                            "e": "e",
+                            "i": "y",
+                            "u": "w",
+                        }
+                        if pinyin[0] in single_rep_map.keys():
+                            pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
+
+                assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
+                new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
+                new_v = new_v + tone
+                phone = [new_c, new_v]
+                word2ph.append(len(phone))
+
+            phones_list += phone
+    return phones_list, word2ph
+
+
+def text_normalize(text):
+    # https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
+    tx = TextNormalizer()
+    sentences = tx.normalize(text)
+    dest_text = ""
+    for sentence in sentences:
+        dest_text += replace_punctuation(sentence)
+    return dest_text
+
+
+if __name__ == "__main__":
+    text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏"
+    text = "呣呣呣~就是…大人的鼹鼠党吧?"
+    text = "你好"
+    text = text_normalize(text)
+    print(g2p(text))
+
+
+# # Example usage
+# text = "这是一个示例文本:,你好!这是一个测试..."
+# print(g2p_paddle(text))  # output: 这是一个示例文本你好这是一个测试
SongBloom/g2p/cn_zh_g2p/cmudict-fast.rep
ADDED
The diff for this file is too large to render.
SongBloom/g2p/cn_zh_g2p/cmudict.rep
ADDED
The diff for this file is too large to render.
SongBloom/g2p/cn_zh_g2p/engdict-hot.rep
ADDED
@@ -0,0 +1,2 @@
+CHATGPT CH AE1 T JH IY1 P IY1 T IY1
+JSON JH EY1 S AH0 N
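
These hot entries use the one-line format that english.py (further below) expects: an uppercased word followed by space-separated ARPA phones. `hot_reload_hot()` loads them after the cached dictionary, so they override cmudict. A small sketch of how one line decomposes (the dict shape mirrors english.py's `g2p_dict`):

```python
# How one hot-dictionary line above is parsed (mirrors hot_reload_hot in english.py).
line = "JSON JH EY1 S AH0 N"
word_split = line.split(" ")
entry = {word_split[0].lower(): [word_split[1:]]}
print(entry)  # {'json': [['JH', 'EY1', 'S', 'AH0', 'N']]}
```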
SongBloom/g2p/cn_zh_g2p/engdict_cache.pickle
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9bff9393f4b192d873a11335efc8f124771087b6dc847d34fd240c2846889d2b
+size 5965909
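
This file (like the other .pickle/.zip files below) is stored with Git LFS, so the diff shows only the three-line pointer: the spec version, the sha256 of the real object, and its size in bytes. A small sketch, not from the repo, that splits such a pointer into fields:

```python
# Sketch: parse a Git LFS pointer file (version / oid / size) into a dict.
def read_lfs_pointer(path):
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

# e.g. read_lfs_pointer("SongBloom/g2p/cn_zh_g2p/engdict_cache.pickle")
# -> {'version': 'https://git-lfs.github.com/spec/v1',
#     'oid': 'sha256:9bff...', 'size': '5965909'}
```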
SongBloom/g2p/cn_zh_g2p/english.py
ADDED
@@ -0,0 +1,369 @@
+import pickle
+import os
+import re
+import wordsegment
+from g2p_en import G2p
+
+from string import punctuation
+
+from .symbols import symbols
+
+import unicodedata
+from builtins import str as unicode
+from g2p_en.expand import normalize_numbers
+
+# Set NLTK data path programmatically to avoid needing set_env.sh
+import nltk
+current_file_path = os.path.dirname(__file__)
+nltk_data_path = os.path.join(current_file_path, "nltk_data")
+if os.path.exists(nltk_data_path):
+    nltk.data.path.insert(0, nltk_data_path)
+
+from nltk.tokenize import TweetTokenizer
+word_tokenize = TweetTokenizer().tokenize
+from nltk import pos_tag
+
+CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
+CMU_DICT_FAST_PATH = os.path.join(current_file_path, "cmudict-fast.rep")
+CMU_DICT_HOT_PATH = os.path.join(current_file_path, "engdict-hot.rep")
+CACHE_PATH = os.path.join(current_file_path, "engdict_cache.pickle")
+NAMECACHE_PATH = os.path.join(current_file_path, "namedict_cache.pickle")
+
+arpa = {
+    "AH0",
+    "S",
+    "AH1",
+    "EY2",
+    "AE2",
+    "EH0",
+    "OW2",
+    "UH0",
+    "NG",
+    "B",
+    "G",
+    "AY0",
+    "M",
+    "AA0",
+    "F",
+    "AO0",
+    "ER2",
+    "UH1",
+    "IY1",
+    "AH2",
+    "DH",
+    "IY0",
+    "EY1",
+    "IH0",
+    "K",
+    "N",
+    "W",
+    "IY2",
+    "T",
+    "AA1",
+    "ER1",
+    "EH2",
+    "OY0",
+    "UH2",
+    "UW1",
+    "Z",
+    "AW2",
+    "AW1",
+    "V",
+    "UW2",
+    "AA2",
+    "ER",
+    "AW0",
+    "UW0",
+    "R",
+    "OW1",
+    "EH1",
+    "ZH",
+    "AE0",
+    "IH2",
+    "IH",
+    "Y",
+    "JH",
+    "P",
+    "AY1",
+    "EY0",
+    "OY2",
+    "TH",
+    "HH",
+    "D",
+    "ER0",
+    "CH",
+    "AO1",
+    "AE1",
+    "AO2",
+    "OY1",
+    "AY2",
+    "IH1",
+    "OW0",
+    "L",
+    "SH",
+}
+
+
+def replace_phs(phs):
+    rep_map = {"'": "-"}
+    phs_new = []
+    for ph in phs:
+        if ph in symbols:
+            phs_new.append(ph)
+        elif ph in rep_map.keys():
+            phs_new.append(rep_map[ph])
+        else:
+            print("ph not in symbols: ", ph)
+    return phs_new
+
+
+def read_dict():
+    g2p_dict = {}
+    start_line = 49
+    with open(CMU_DICT_PATH) as f:
+        line = f.readline()
+        line_index = 1
+        while line:
+            if line_index >= start_line:
+                line = line.strip()
+                word_split = line.split(" ")
+                word = word_split[0].lower()
+
+                syllable_split = word_split[1].split(" - ")
+                g2p_dict[word] = []
+                for syllable in syllable_split:
+                    phone_split = syllable.split(" ")
+                    g2p_dict[word].append(phone_split)
+
+            line_index = line_index + 1
+            line = f.readline()
+
+    return g2p_dict
+
+
+def read_dict_new():
+    g2p_dict = {}
+    with open(CMU_DICT_PATH) as f:
+        line = f.readline()
+        line_index = 1
+        while line:
+            if line_index >= 57:
+                line = line.strip()
+                word_split = line.split(" ")
+                word = word_split[0].lower()
+                g2p_dict[word] = [word_split[1].split(" ")]
+
+            line_index = line_index + 1
+            line = f.readline()
+
+    with open(CMU_DICT_FAST_PATH) as f:
+        line = f.readline()
+        line_index = 1
+        while line:
+            if line_index >= 0:
+                line = line.strip()
+                word_split = line.split(" ")
+                word = word_split[0].lower()
+                if word not in g2p_dict:
+                    g2p_dict[word] = [word_split[1:]]
+
+            line_index = line_index + 1
+            line = f.readline()
+
+    return g2p_dict
+
+def hot_reload_hot(g2p_dict):
+    with open(CMU_DICT_HOT_PATH) as f:
+        line = f.readline()
+        line_index = 1
+        while line:
+            if line_index >= 0:
+                line = line.strip()
+                word_split = line.split(" ")
+                word = word_split[0].lower()
+                # custom pronunciations directly override the dictionary
+                g2p_dict[word] = [word_split[1:]]
+
+            line_index = line_index + 1
+            line = f.readline()
+
+    return g2p_dict
+
+
+def cache_dict(g2p_dict, file_path):
+    with open(file_path, "wb") as pickle_file:
+        pickle.dump(g2p_dict, pickle_file)
+
+
+def get_dict():
+    if os.path.exists(CACHE_PATH):
+        with open(CACHE_PATH, "rb") as pickle_file:
+            g2p_dict = pickle.load(pickle_file)
+    else:
+        g2p_dict = read_dict_new()
+        cache_dict(g2p_dict, CACHE_PATH)
+
+    g2p_dict = hot_reload_hot(g2p_dict)
+
+    return g2p_dict
+
+
+def get_namedict():
+    if os.path.exists(NAMECACHE_PATH):
+        with open(NAMECACHE_PATH, "rb") as pickle_file:
+            name_dict = pickle.load(pickle_file)
+    else:
+        name_dict = {}
+
+    return name_dict
+
+
+def text_normalize(text):
+    # todo: English text normalization
+    # adapt Chinese and g2p_en punctuation
+    rep_map = {
+        "[;::,;]": ",",
+        '["’]': "'",
+        "。": ".",
+        "!": "!",
+        "?": "?",
+    }
+    for p, r in rep_map.items():
+        text = re.sub(p, r, text)
+
+    # text formatting taken from g2p_en,
+    # with added uppercase compatibility
+    text = unicode(text)
+    text = normalize_numbers(text)
+    text = ''.join(char for char in unicodedata.normalize('NFD', text)
+                   if unicodedata.category(char) != 'Mn')  # Strip accents
+    text = re.sub("[^ A-Za-z'.,?!\-]", "", text)
+    text = re.sub(r"(?i)i\.e\.", "that is", text)
+    text = re.sub(r"(?i)e\.g\.", "for example", text)
+
+    return text
+
+
+class en_G2p(G2p):
+    def __init__(self):
+        super().__init__()
+        # initialize the word segmenter
+        wordsegment.load()
+
+        # extend the outdated dictionary and add the name dictionary
+        self.cmu = get_dict()
+        self.namedict = get_namedict()
+
+        # remove a few abbreviations with incorrect pronunciations
+        for word in ["AE", "AI", "AR", "IOS", "HUD", "OS"]:
+            del self.cmu[word.lower()]
+
+        # fix heteronyms
+        self.homograph2features["read"] = (['R', 'IY1', 'D'], ['R', 'EH1', 'D'], 'VBP')
+        self.homograph2features["complex"] = (['K', 'AH0', 'M', 'P', 'L', 'EH1', 'K', 'S'], ['K', 'AA1', 'M', 'P', 'L', 'EH0', 'K', 'S'], 'JJ')
+
+
+    def __call__(self, text):
+        # tokenization
+        words = word_tokenize(text)
+        tokens = pos_tag(words)  # tuples of (word, tag)
+
+        # steps
+        prons = []
+        for o_word, pos in tokens:
+            # replicate g2p_en's lowercasing logic
+            word = o_word.lower()
+
+            if re.search("[a-z]", word) is None:
+                pron = [word]
+            # handle single letters first
+            elif len(word) == 1:
+                # fix the pronunciation of a standalone "A"; the original o_word is needed to check the uppercase form
+                if o_word == "A":
+                    pron = ['EY1']
+                else:
+                    pron = self.cmu[word][0]
+            # original g2p_en homograph handling
+            elif word in self.homograph2features:  # Check homograph
+                pron1, pron2, pos1 = self.homograph2features[word]
+                if pos.startswith(pos1):
+                    pron = pron1
+                # pos1 being longer than pos only occurs for "read"
+                elif len(pos) < len(pos1) and pos == pos1[:len(pos)]:
+                    pron = pron1
+                else:
+                    pron = pron2
+            else:
+                # recursive lookup / prediction
+                pron = self.qryword(o_word)
+
+            prons.extend(pron)
+            prons.extend([" "])
+
+        return prons[:-1]
+
+
+    def qryword(self, o_word):
+        word = o_word.lower()
+
+        # dictionary lookup, except for single letters
+        if len(word) > 1 and word in self.cmu:  # lookup CMU dict
+            return self.cmu[word][0]
+
+        # consult the name dictionary only when the word is title-cased
+        if o_word.istitle() and word in self.namedict:
+            return self.namedict[word][0]
+
+        # for OOV words of length <= 3, spell out the letters
+        if len(word) <= 3:
+            phones = []
+            for w in word:
+                # fix the pronunciation of a standalone "a"; no uppercase case here
+                if w == "a":
+                    phones.extend(['EY1'])
+                else:
+                    phones.extend(self.cmu[w][0])
+            return phones
+
+        # try to split off a possessive
+        if re.match(r"^([a-z]+)('s)$", word):
+            phones = self.qryword(word[:-2])[:]
+            # after the voiceless consonants P T K F TH HH, 's is pronounced ['S']
+            if phones[-1] in ['P', 'T', 'K', 'F', 'TH', 'HH']:
+                phones.extend(['S'])
+            # after the sibilants S Z SH ZH CH JH, 's is pronounced ['IH1', 'Z'] or ['AH0', 'Z']
+            elif phones[-1] in ['S', 'Z', 'SH', 'ZH', 'CH', 'JH']:
+                phones.extend(['AH0', 'Z'])
+            # after the voiced consonants B D G DH V M N NG L R W Y, 's is pronounced ['Z']
+            # after the vowels AH0 AH1 AH2 EY0 EY1 EY2 AE0 AE1 AE2 EH0 EH1 EH2 OW0 OW1 OW2 UH0 UH1 UH2 IY0 IY1 IY2 AA0 AA1 AA2 AO0 AO1 AO2
+            # ER ER0 ER1 ER2 UW0 UW1 UW2 AY0 AY1 AY2 AW0 AW1 AW2 OY0 OY1 OY2 IH IH0 IH1 IH2, 's is pronounced ['Z']
+            else:
+                phones.extend(['Z'])
+            return phones
+
+        # try word segmentation to handle compound words
+        comps = wordsegment.segment(word.lower())
+
+        # if it cannot be segmented, fall back to model prediction
+        if len(comps)==1:
+            return self.predict(word)
+
+        # if it can be segmented, process each component recursively
+        return [phone for comp in comps for phone in self.qryword(comp)]
+
+
+_g2p = en_G2p()
+
+
+def g2p(text):
+    # run g2p_en on the whole passage, dropping returned symbols that are not valid ARPA
+    phone_list = _g2p(text)
+    phones = [ph if ph != "<unk>" else "UNK" for ph in phone_list if ph not in [" ", "<pad>", "UW", "</s>", "<s>"]]
+
+    return replace_phs(phones)
+
+
+if __name__ == "__main__":
+    print(g2p("hello"))
+    print(g2p(text_normalize("e.g. I used openai's AI tool to draw a picture.")))
+    print(g2p(text_normalize("In this; paper, we propose 1 DSPGAN, a GAN-based universal vocoder.")))
SongBloom/g2p/cn_zh_g2p/nltk_data/corpora/cmudict.zip
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d07cca47fd72ad32ea9d8ad1219f85301eeaf4568f8b6b73747506a71fb5afd6
+size 896069
SongBloom/g2p/cn_zh_g2p/nltk_data/corpora/cmudict/README
ADDED
@@ -0,0 +1,76 @@
+The Carnegie Mellon Pronouncing Dictionary [cmudict.0.7a]
+
+ftp://ftp.cs.cmu.edu/project/speech/dict/
+https://cmusphinx.svn.sourceforge.net/svnroot/cmusphinx/trunk/cmudict/cmudict.0.7a
+
+Copyright (C) 1993-2008 Carnegie Mellon University. All rights reserved.
+
+File Format: Each line consists of an uppercased word,
+a counter (for alternative pronunciations), and a transcription.
+Vowels are marked for stress (1=primary, 2=secondary, 0=no stress).
+E.g.: NATURAL 1 N AE1 CH ER0 AH0 L
+
+The dictionary contains 127069 entries.  Of these, 119400 words are assigned
+a unique pronunciation, 6830 words have two pronunciations, and 839 words have
+three or more pronunciations.  Many of these are fast-speech variants.
+
+Phonemes: There are 39 phonemes, as shown below:
+
+Phoneme Example Translation    Phoneme Example Translation
+------- ------- -----------    ------- ------- -----------
+AA      odd     AA D           AE      at      AE T
+AH      hut     HH AH T        AO      ought   AO T
+AW      cow     K AW           AY      hide    HH AY D
+B       be      B IY           CH      cheese  CH IY Z
+D       dee     D IY           DH      thee    DH IY
+EH      Ed      EH D           ER      hurt    HH ER T
+EY      ate     EY T           F       fee     F IY
+G       green   G R IY N       HH      he      HH IY
+IH      it      IH T           IY      eat     IY T
+JH      gee     JH IY          K       key     K IY
+L       lee     L IY           M       me      M IY
+N       knee    N IY           NG      ping    P IH NG
+OW      oat     OW T           OY      toy     T OY
+P       pee     P IY           R       read    R IY D
+S       sea     S IY           SH      she     SH IY
+T       tea     T IY           TH      theta   TH EY T AH
+UH      hood    HH UH D        UW      two     T UW
+V       vee     V IY           W       we      W IY
+Y       yield   Y IY L D       Z       zee     Z IY
+ZH      seizure S IY ZH ER
+
+(For NLTK, entries have been sorted so that, e.g. FIRE 1 and FIRE 2
+are contiguous, and not separated by FIRE'S 1.)
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions
+are met:
+
+1. Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+   The contents of this file are deemed to be source code.
+
+2. Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in
+   the documentation and/or other materials provided with the
+   distribution.
+
+This work was supported in part by funding from the Defense Advanced
+Research Projects Agency, the Office of Naval Research and the National
+Science Foundation of the United States of America, and by member
+companies of the Carnegie Mellon Sphinx Speech Consortium. We acknowledge
+the contributions of many volunteers to the expansion and improvement of
+this dictionary.
+
+THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND
+ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
+NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
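
The format above (word, counter, space-separated ARPA phones) is straightforward to parse; a hedged sketch, not from the repo, that groups alternative pronunciations by word (the comment/encoding handling is an assumption about the NLTK copy):

```python
# Hedged sketch: parse the cmudict format described above.
from collections import defaultdict

def load_cmudict(path):
    entries = defaultdict(list)
    with open(path, encoding="latin-1") as f:  # encoding is an assumption
        for raw in f:
            parts = raw.split()
            if len(parts) < 3 or raw.startswith(";;;"):  # skip short/comment lines
                continue
            word, _counter, phones = parts[0], parts[1], parts[2:]
            entries[word].append(phones)  # the counter orders the alternatives
    return entries

# load_cmudict(".../cmudict")["NATURAL"] -> [['N', 'AE1', 'CH', 'ER0', 'AH0', 'L']]
```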
SongBloom/g2p/cn_zh_g2p/nltk_data/corpora/cmudict/cmudict
ADDED
The diff for this file is too large to render.
SongBloom/g2p/cn_zh_g2p/nltk_data/taggers/averaged_perceptron_tagger.zip
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e1f13cf2532daadfd6f3bc481a49859f0b8ea6432ccdcd83e6a49a5f19008de9
+size 2526731
SongBloom/g2p/cn_zh_g2p/nltk_data/taggers/averaged_perceptron_tagger/averaged_perceptron_tagger.pickle
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25a5a19c7ced7b2bac3831da5bc0afcc2c34e5dd01cd4f361bb799949a696238
+size 6138625
SongBloom/g2p/cn_zh_g2p/opencpop-strict.txt
ADDED
@@ -0,0 +1,429 @@
+a AA a
+ai AA ai
+an AA an
+ang AA ang
+ao AA ao
+ba b a
+bai b ai
+ban b an
+bang b ang
+bao b ao
+bei b ei
+ben b en
+beng b eng
+bi b i
+bian b ian
+biao b iao
+bie b ie
+bin b in
+bing b ing
+bo b o
+bu b u
+ca c a
+cai c ai
+can c an
+cang c ang
+cao c ao
+ce c e
+cei c ei
+cen c en
+ceng c eng
+cha ch a
+chai ch ai
+chan ch an
+chang ch ang
+chao ch ao
+che ch e
+chen ch en
+cheng ch eng
+chi ch ir
+chong ch ong
+chou ch ou
+chu ch u
+chua ch ua
+chuai ch uai
+chuan ch uan
+chuang ch uang
+chui ch ui
+chun ch un
+chuo ch uo
+ci c i0
+cong c ong
+cou c ou
+cu c u
+cuan c uan
+cui c ui
+cun c un
+cuo c uo
+da d a
+dai d ai
+dan d an
+dang d ang
+dao d ao
+de d e
+dei d ei
+den d en
+deng d eng
+di d i
+dia d ia
+dian d ian
+diao d iao
+die d ie
+ding d ing
+diu d iu
+dong d ong
+dou d ou
+du d u
+duan d uan
+dui d ui
+dun d un
+duo d uo
+e EE e
+ei EE ei
+en EE en
+eng EE eng
+er EE er
+fa f a
+fan f an
+fang f ang
+fei f ei
+fen f en
+feng f eng
+fo f o
+fou f ou
+fu f u
+ga g a
+gai g ai
+gan g an
+gang g ang
+gao g ao
+ge g e
+gei g ei
+gen g en
+geng g eng
+gong g ong
+gou g ou
+gu g u
+gua g ua
+guai g uai
+guan g uan
+guang g uang
+gui g ui
+gun g un
+guo g uo
+ha h a
+hai h ai
+han h an
+hang h ang
+hao h ao
+he h e
+hei h ei
+hen h en
+heng h eng
+hong h ong
+hou h ou
+hu h u
+hua h ua
+huai h uai
+huan h uan
+huang h uang
+hui h ui
+hun h un
+huo h uo
+ji j i
+jia j ia
+jian j ian
+jiang j iang
+jiao j iao
+jie j ie
+jin j in
+jing j ing
+jiong j iong
+jiu j iu
+ju j v
+jv j v
+juan j van
+jvan j van
+jue j ve
+jve j ve
+jun j vn
+jvn j vn
+ka k a
+kai k ai
+kan k an
+kang k ang
+kao k ao
+ke k e
+kei k ei
+ken k en
+keng k eng
+kong k ong
+kou k ou
+ku k u
+kua k ua
+kuai k uai
+kuan k uan
+kuang k uang
+kui k ui
+kun k un
+kuo k uo
+la l a
+lai l ai
+lan l an
+lang l ang
+lao l ao
+le l e
+lei l ei
+leng l eng
+li l i
+lia l ia
+lian l ian
+liang l iang
+liao l iao
+lie l ie
+lin l in
+ling l ing
+liu l iu
+lo l o
+long l ong
+lou l ou
+lu l u
+luan l uan
+lun l un
+luo l uo
+lv l v
+lve l ve
+ma m a
+mai m ai
+man m an
+mang m ang
+mao m ao
+me m e
+mei m ei
+men m en
+meng m eng
+mi m i
+mian m ian
+miao m iao
+mie m ie
+min m in
+ming m ing
+miu m iu
+mo m o
+mou m ou
+mu m u
+na n a
+nai n ai
+nan n an
+nang n ang
+nao n ao
+ne n e
+nei n ei
+nen n en
+neng n eng
+ni n i
+nian n ian
+niang n iang
+niao n iao
+nie n ie
+nin n in
+ning n ing
+niu n iu
+nong n ong
+nou n ou
+nu n u
+nuan n uan
+nun n un
+nuo n uo
+nv n v
+nve n ve
+o OO o
+ou OO ou
+pa p a
+pai p ai
+pan p an
+pang p ang
+pao p ao
+pei p ei
+pen p en
+peng p eng
+pi p i
+pian p ian
+piao p iao
+pie p ie
+pin p in
+ping p ing
+po p o
+pou p ou
+pu p u
+qi q i
+qia q ia
+qian q ian
+qiang q iang
+qiao q iao
+qie q ie
+qin q in
+qing q ing
+qiong q iong
+qiu q iu
+qu q v
+qv q v
+quan q van
+qvan q van
+que q ve
+qve q ve
+qun q vn
+qvn q vn
+ran r an
+rang r ang
+rao r ao
+re r e
+ren r en
+reng r eng
+ri r ir
+rong r ong
+rou r ou
+ru r u
+rua r ua
+ruan r uan
+rui r ui
+run r un
+ruo r uo
+sa s a
+sai s ai
+san s an
+sang s ang
+sao s ao
+se s e
+sen s en
+seng s eng
+sha sh a
+shai sh ai
+shan sh an
+shang sh ang
+shao sh ao
+she sh e
+shei sh ei
+shen sh en
+sheng sh eng
+shi sh ir
+shou sh ou
+shu sh u
+shua sh ua
+shuai sh uai
+shuan sh uan
+shuang sh uang
+shui sh ui
+shun sh un
+shuo sh uo
+si s i0
+song s ong
+sou s ou
+su s u
+suan s uan
+sui s ui
+sun s un
+suo s uo
+ta t a
+tai t ai
+tan t an
+tang t ang
+tao t ao
+te t e
+tei t ei
+teng t eng
+ti t i
+tian t ian
+tiao t iao
+tie t ie
+ting t ing
+tong t ong
+tou t ou
+tu t u
+tuan t uan
+tui t ui
+tun t un
+tuo t uo
+wa w a
+wai w ai
+wan w an
+wang w ang
+wei w ei
+wen w en
+weng w eng
+wo w o
+wu w u
+xi x i
+xia x ia
+xian x ian
+xiang x iang
+xiao x iao
+xie x ie
+xin x in
+xing x ing
+xiong x iong
+xiu x iu
+xu x v
+xv x v
+xuan x van
+xvan x van
+xue x ve
+xve x ve
+xun x vn
+xvn x vn
+ya y a
+yan y En
+yang y ang
+yao y ao
+ye y E
+yi y i
+yin y in
+ying y ing
+yo y o
+yong y ong
+you y ou
+yu y v
+yv y v
+yuan y van
+yvan y van
+yue y ve
+yve y ve
+yun y vn
+yvn y vn
+za z a
+zai z ai
+zan z an
+zang z ang
+zao z ao
+ze z e
+zei z ei
+zen z en
+zeng z eng
+zha zh a
+zhai zh ai
+zhan zh an
+zhang zh ang
+zhao zh ao
+zhe zh e
+zhei zh ei
+zhen zh en
+zheng zh eng
+zhi zh ir
+zhong zh ong
+zhou zh ou
+zhu zh u
+zhua zh ua
+zhuai zh uai
+zhuan zh uan
+zhuang zh uang
+zhui zh ui
+zhun zh un
+zhuo zh uo
+zi z i0
+zong z ong
+zou z ou
+zu z u
+zuan z uan
+zui z ui
+zun z un
+zuo z uo
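
Each row maps a pinyin syllable to its "initial final" decomposition (in the raw file the syllable and the symbols are tab-separated, as the loader in chinese.py shows). A minimal sketch of the same lookup that chinese.py builds as `pinyin_to_symbol_map`:

```python
# Minimal sketch: load the syllable -> "initial final" table above,
# the same way chinese.py builds pinyin_to_symbol_map.
mapping = {
    line.split("\t")[0]: line.strip().split("\t")[1]
    for line in open("SongBloom/g2p/cn_zh_g2p/opencpop-strict.txt")
}
print(mapping["zhong"])             # "zh ong"
print(mapping["zhong"].split(" "))  # ['zh', 'ong'] -> initial, final
```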
SongBloom/g2p/cn_zh_g2p/symbols.py
ADDED
@@ -0,0 +1,401 @@
+import os
+
+# punctuation = ['!', '?', '…', ",", ".", "@"]  # @ is the SP pause
+punctuation = ["!", "?", "…", ",", "."]  # @ is the SP pause
+punctuation.append("-")
+pu_symbols = punctuation + ["SP", "SP2", "SP3", "UNK"]
+# pu_symbols = punctuation + ["SP", 'SP2', 'SP3', 'SP4', "UNK"]
+pad = "_"
+
+c = [
+    "AA",
+    "EE",
+    "OO",
+    "b",
+    "c",
+    "ch",
+    "d",
+    "f",
+    "g",
+    "h",
+    "j",
+    "k",
+    "l",
+    "m",
+    "n",
+    "p",
+    "q",
+    "r",
+    "s",
+    "sh",
+    "t",
+    "w",
+    "x",
+    "y",
+    "z",
+    "zh",
+]
+v = [
+    "E1",
+    "En1",
+    "a1",
+    "ai1",
+    "an1",
+    "ang1",
+    "ao1",
+    "e1",
+    "ei1",
+    "en1",
+    "eng1",
+    "er1",
+    "i1",
+    "i01",
+    "ia1",
+    "ian1",
+    "iang1",
+    "iao1",
+    "ie1",
+    "in1",
+    "ing1",
+    "iong1",
+    "ir1",
+    "iu1",
+    "o1",
+    "ong1",
+    "ou1",
+    "u1",
+    "ua1",
+    "uai1",
+    "uan1",
+    "uang1",
+    "ui1",
+    "un1",
+    "uo1",
+    "v1",
+    "van1",
+    "ve1",
+    "vn1",
+    "E2",
+    "En2",
+    "a2",
+    "ai2",
+    "an2",
+    "ang2",
+    "ao2",
+    "e2",
+    "ei2",
+    "en2",
+    "eng2",
+    "er2",
+    "i2",
+    "i02",
+    "ia2",
+    "ian2",
+    "iang2",
+    "iao2",
+    "ie2",
+    "in2",
+    "ing2",
+    "iong2",
+    "ir2",
+    "iu2",
+    "o2",
+    "ong2",
+    "ou2",
+    "u2",
+    "ua2",
+    "uai2",
+    "uan2",
+    "uang2",
+    "ui2",
+    "un2",
+    "uo2",
+    "v2",
+    "van2",
+    "ve2",
+    "vn2",
+    "E3",
+    "En3",
+    "a3",
+    "ai3",
+    "an3",
+    "ang3",
+    "ao3",
+    "e3",
+    "ei3",
+    "en3",
+    "eng3",
+    "er3",
+    "i3",
+    "i03",
+    "ia3",
+    "ian3",
+    "iang3",
+    "iao3",
+    "ie3",
+    "in3",
+    "ing3",
+    "iong3",
+    "ir3",
+    "iu3",
+    "o3",
+    "ong3",
+    "ou3",
+    "u3",
+    "ua3",
+    "uai3",
+    "uan3",
+    "uang3",
+    "ui3",
+    "un3",
+    "uo3",
+    "v3",
+    "van3",
+    "ve3",
+    "vn3",
+    "E4",
+    "En4",
+    "a4",
+    "ai4",
+    "an4",
+    "ang4",
+    "ao4",
+    "e4",
+    "ei4",
+    "en4",
+    "eng4",
+    "er4",
+    "i4",
+    "i04",
+    "ia4",
+    "ian4",
+    "iang4",
+    "iao4",
+    "ie4",
+    "in4",
+    "ing4",
+    "iong4",
+    "ir4",
+    "iu4",
+    "o4",
+    "ong4",
+    "ou4",
+    "u4",
+    "ua4",
+    "uai4",
+    "uan4",
+    "uang4",
+    "ui4",
+    "un4",
+    "uo4",
+    "v4",
+    "van4",
+    "ve4",
+    "vn4",
+    "E5",
+    "En5",
+    "a5",
+    "ai5",
+    "an5",
+    "ang5",
+    "ao5",
+    "e5",
+    "ei5",
+    "en5",
+    "eng5",
+    "er5",
+    "i5",
+    "i05",
+    "ia5",
+    "ian5",
+    "iang5",
+    "iao5",
+    "ie5",
+    "in5",
+    "ing5",
+    "iong5",
+    "ir5",
+    "iu5",
+    "o5",
+    "ong5",
+    "ou5",
+    "u5",
+    "ua5",
+    "uai5",
+    "uan5",
+    "uang5",
+    "ui5",
+    "un5",
"uo5",
|
230 |
+
"v5",
|
231 |
+
"van5",
|
232 |
+
"ve5",
|
233 |
+
"vn5",
|
234 |
+
]
|
235 |
+
|
236 |
+
v_without_tone = [
|
237 |
+
"E",
|
238 |
+
"En",
|
239 |
+
"a",
|
240 |
+
"ai",
|
241 |
+
"an",
|
242 |
+
"ang",
|
243 |
+
"ao",
|
244 |
+
"e",
|
245 |
+
"ei",
|
246 |
+
"en",
|
247 |
+
"eng",
|
248 |
+
"er",
|
249 |
+
"i",
|
250 |
+
"i0",
|
251 |
+
"ia",
|
252 |
+
"ian",
|
253 |
+
"iang",
|
254 |
+
"iao",
|
255 |
+
"ie",
|
256 |
+
"in",
|
257 |
+
"ing",
|
258 |
+
"iong",
|
259 |
+
"ir",
|
260 |
+
"iu",
|
261 |
+
"o",
|
262 |
+
"ong",
|
263 |
+
"ou",
|
264 |
+
"u",
|
265 |
+
"ua",
|
266 |
+
"uai",
|
267 |
+
"uan",
|
268 |
+
"uang",
|
269 |
+
"ui",
|
270 |
+
"un",
|
271 |
+
"uo",
|
272 |
+
"v",
|
273 |
+
"van",
|
274 |
+
"ve",
|
275 |
+
"vn",
|
276 |
+
]
|
277 |
+
|
278 |
+
# japanese
|
279 |
+
ja_symbols = [
|
280 |
+
"I",
|
281 |
+
"N",
|
282 |
+
"U",
|
283 |
+
"a",
|
284 |
+
"b",
|
285 |
+
"by",
|
286 |
+
"ch",
|
287 |
+
"cl",
|
288 |
+
"d",
|
289 |
+
"dy",
|
290 |
+
"e",
|
291 |
+
"f",
|
292 |
+
"g",
|
293 |
+
"gy",
|
294 |
+
"h",
|
295 |
+
"hy",
|
296 |
+
"i",
|
297 |
+
"j",
|
298 |
+
"k",
|
299 |
+
"ky",
|
300 |
+
"m",
|
301 |
+
"my",
|
302 |
+
"n",
|
303 |
+
"ny",
|
304 |
+
"o",
|
305 |
+
"p",
|
306 |
+
"py",
|
307 |
+
"r",
|
308 |
+
"ry",
|
309 |
+
"s",
|
310 |
+
"sh",
|
311 |
+
"t",
|
312 |
+
"ts",
|
313 |
+
"u",
|
314 |
+
"v",
|
315 |
+
"w",
|
316 |
+
"y",
|
317 |
+
"z",
|
318 |
+
# "[", #上升调型
|
319 |
+
# "]", #下降调型
|
320 |
+
# "$", #结束符
|
321 |
+
# "^", #开始符
|
322 |
+
]
|
323 |
+
|
324 |
+
arpa = {
|
325 |
+
"AH0",
|
326 |
+
"S",
|
327 |
+
"AH1",
|
328 |
+
"EY2",
|
329 |
+
"AE2",
|
330 |
+
"EH0",
|
331 |
+
"OW2",
|
332 |
+
"UH0",
|
333 |
+
"NG",
|
334 |
+
"B",
|
335 |
+
"G",
|
336 |
+
"AY0",
|
337 |
+
"M",
|
338 |
+
"AA0",
|
339 |
+
"F",
|
340 |
+
"AO0",
|
341 |
+
"ER2",
|
342 |
+
"UH1",
|
343 |
+
"IY1",
|
344 |
+
"AH2",
|
345 |
+
"DH",
|
346 |
+
"IY0",
|
347 |
+
"EY1",
|
348 |
+
"IH0",
|
349 |
+
"K",
|
350 |
+
"N",
|
351 |
+
"W",
|
352 |
+
"IY2",
|
353 |
+
"T",
|
354 |
+
"AA1",
|
355 |
+
"ER1",
|
356 |
+
"EH2",
|
357 |
+
"OY0",
|
358 |
+
"UH2",
|
359 |
+
"UW1",
|
360 |
+
"Z",
|
361 |
+
"AW2",
|
362 |
+
"AW1",
|
363 |
+
"V",
|
364 |
+
"UW2",
|
365 |
+
"AA2",
|
366 |
+
"ER",
|
367 |
+
"AW0",
|
368 |
+
"UW0",
|
369 |
+
"R",
|
370 |
+
"OW1",
|
371 |
+
"EH1",
|
372 |
+
"ZH",
|
373 |
+
"AE0",
|
374 |
+
"IH2",
|
375 |
+
"IH",
|
376 |
+
"Y",
|
377 |
+
"JH",
|
378 |
+
"P",
|
379 |
+
"AY1",
|
380 |
+
"EY0",
|
381 |
+
"OY2",
|
382 |
+
"TH",
|
383 |
+
"HH",
|
384 |
+
"D",
|
385 |
+
"ER0",
|
386 |
+
"CH",
|
387 |
+
"AO1",
|
388 |
+
"AE1",
|
389 |
+
"AO2",
|
390 |
+
"OY1",
|
391 |
+
"AY2",
|
392 |
+
"IH1",
|
393 |
+
"OW0",
|
394 |
+
"L",
|
395 |
+
"SH",
|
396 |
+
}
|
397 |
+
|
398 |
+
symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)
|
399 |
+
symbols = sorted(set(symbols))
|
400 |
+
if __name__ == "__main__":
|
401 |
+
print(len(symbols))
|
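A minimal usage sketch (an assumption about how such symbol tables are consumed downstream, not code from this commit): models typically index phonemes through a symbol-to-id lookup built once from `symbols`:
symbol_to_id = {s: i for i, s in enumerate(symbols)}
unk_id = symbol_to_id["UNK"]
# encode a phoneme sequence, falling back to UNK for unknown symbols
ids = [symbol_to_id.get(ph, unk_id) for ph in ["x", "iang3", "SP"]]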
SongBloom/g2p/cn_zh_g2p/tone_sandhi.py
ADDED
@@ -0,0 +1,806 @@
1 |
+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import List
|
15 |
+
from typing import Tuple
|
16 |
+
|
17 |
+
import jieba_fast as jieba
|
18 |
+
from pypinyin import lazy_pinyin
|
19 |
+
from pypinyin import Style
|
20 |
+
|
21 |
+
|
22 |
+
class ToneSandhi:
|
23 |
+
def __init__(self):
|
24 |
+
self.must_neural_tone_words = {
|
25 |
+
"麻烦",
|
26 |
+
"麻利",
|
27 |
+
"鸳鸯",
|
28 |
+
"高粱",
|
29 |
+
"骨头",
|
30 |
+
"骆驼",
|
31 |
+
"马虎",
|
32 |
+
"首饰",
|
33 |
+
"馒头",
|
34 |
+
"馄饨",
|
35 |
+
"风筝",
|
36 |
+
"难为",
|
37 |
+
"队伍",
|
38 |
+
"阔气",
|
39 |
+
"闺女",
|
40 |
+
"门道",
|
41 |
+
"锄头",
|
42 |
+
"铺盖",
|
43 |
+
"铃铛",
|
44 |
+
"铁匠",
|
45 |
+
"钥匙",
|
46 |
+
"里脊",
|
47 |
+
"里头",
|
48 |
+
"部分",
|
49 |
+
"那么",
|
50 |
+
"道士",
|
51 |
+
"造化",
|
52 |
+
"迷糊",
|
53 |
+
"连累",
|
54 |
+
"这么",
|
55 |
+
"这个",
|
56 |
+
"运气",
|
57 |
+
"过去",
|
58 |
+
"软和",
|
59 |
+
"转悠",
|
60 |
+
"踏实",
|
61 |
+
"跳蚤",
|
62 |
+
"跟头",
|
63 |
+
"趔趄",
|
64 |
+
"财主",
|
65 |
+
"豆腐",
|
66 |
+
"讲究",
|
67 |
+
"记性",
|
68 |
+
"记号",
|
69 |
+
"认识",
|
70 |
+
"规矩",
|
71 |
+
"见识",
|
72 |
+
"裁缝",
|
73 |
+
"补丁",
|
74 |
+
"衣裳",
|
75 |
+
"衣服",
|
76 |
+
"衙门",
|
77 |
+
"街坊",
|
78 |
+
"行李",
|
79 |
+
"行当",
|
80 |
+
"蛤蟆",
|
81 |
+
"蘑菇",
|
82 |
+
"薄荷",
|
83 |
+
"葫芦",
|
84 |
+
"葡萄",
|
85 |
+
"萝卜",
|
86 |
+
"荸荠",
|
87 |
+
"苗条",
|
88 |
+
"苗头",
|
89 |
+
"苍蝇",
|
90 |
+
"芝麻",
|
91 |
+
"舒服",
|
92 |
+
"舒坦",
|
93 |
+
"舌头",
|
94 |
+
"自在",
|
95 |
+
"膏药",
|
96 |
+
"脾气",
|
97 |
+
"脑袋",
|
98 |
+
"脊梁",
|
99 |
+
"能耐",
|
100 |
+
"胳膊",
|
101 |
+
"胭脂",
|
102 |
+
"胡萝",
|
103 |
+
"胡琴",
|
104 |
+
"胡同",
|
105 |
+
"聪明",
|
106 |
+
"耽误",
|
107 |
+
"耽搁",
|
108 |
+
"耷拉",
|
109 |
+
"耳朵",
|
110 |
+
"老爷",
|
111 |
+
"老实",
|
112 |
+
"老婆",
|
113 |
+
"老头",
|
114 |
+
"老太",
|
115 |
+
"翻腾",
|
116 |
+
"罗嗦",
|
117 |
+
"罐头",
|
118 |
+
"编辑",
|
119 |
+
"结实",
|
120 |
+
"红火",
|
121 |
+
"累赘",
|
122 |
+
"糨糊",
|
123 |
+
"糊涂",
|
124 |
+
"精神",
|
125 |
+
"粮食",
|
126 |
+
"簸箕",
|
127 |
+
"篱笆",
|
128 |
+
"算计",
|
129 |
+
"算盘",
|
130 |
+
"答应",
|
131 |
+
"笤帚",
|
132 |
+
"笑语",
|
133 |
+
"笑话",
|
134 |
+
"窟窿",
|
135 |
+
"窝囊",
|
136 |
+
"窗户",
|
137 |
+
"稳当",
|
138 |
+
"稀罕",
|
139 |
+
"称呼",
|
140 |
+
"秧歌",
|
141 |
+
"秀气",
|
142 |
+
"秀才",
|
143 |
+
"福气",
|
144 |
+
"祖宗",
|
145 |
+
"砚台",
|
146 |
+
"码头",
|
147 |
+
"石榴",
|
148 |
+
"石头",
|
149 |
+
"石匠",
|
150 |
+
"知识",
|
151 |
+
"眼睛",
|
152 |
+
"眯缝",
|
153 |
+
"眨巴",
|
154 |
+
"眉毛",
|
155 |
+
"相声",
|
156 |
+
"盘算",
|
157 |
+
"白净",
|
158 |
+
"痢疾",
|
159 |
+
"痛快",
|
160 |
+
"疟疾",
|
161 |
+
"疙瘩",
|
162 |
+
"疏忽",
|
163 |
+
"畜生",
|
164 |
+
"生意",
|
165 |
+
"甘蔗",
|
166 |
+
"琵琶",
|
167 |
+
"琢磨",
|
168 |
+
"琉璃",
|
169 |
+
"玻璃",
|
170 |
+
"玫瑰",
|
171 |
+
"玄乎",
|
172 |
+
"狐狸",
|
173 |
+
"状元",
|
174 |
+
"特务",
|
175 |
+
"牲口",
|
176 |
+
"牙碜",
|
177 |
+
"牌楼",
|
178 |
+
"爽快",
|
179 |
+
"爱人",
|
180 |
+
"热闹",
|
181 |
+
"烧饼",
|
182 |
+
"烟筒",
|
183 |
+
"烂糊",
|
184 |
+
"点心",
|
185 |
+
"炊帚",
|
186 |
+
"灯笼",
|
187 |
+
"火候",
|
188 |
+
"漂亮",
|
189 |
+
"滑溜",
|
190 |
+
"溜达",
|
191 |
+
"温和",
|
192 |
+
"清楚",
|
193 |
+
"消息",
|
194 |
+
"浪头",
|
195 |
+
"活泼",
|
196 |
+
"比方",
|
197 |
+
"正经",
|
198 |
+
"欺负",
|
199 |
+
"模糊",
|
200 |
+
"槟榔",
|
201 |
+
"棺材",
|
202 |
+
"棒槌",
|
203 |
+
"棉花",
|
204 |
+
"核桃",
|
205 |
+
"栅栏",
|
206 |
+
"柴火",
|
207 |
+
"架势",
|
208 |
+
"枕头",
|
209 |
+
"���杷",
|
210 |
+
"机灵",
|
211 |
+
"本事",
|
212 |
+
"木头",
|
213 |
+
"木匠",
|
214 |
+
"朋友",
|
215 |
+
"月饼",
|
216 |
+
"月亮",
|
217 |
+
"暖和",
|
218 |
+
"明白",
|
219 |
+
"时候",
|
220 |
+
"新鲜",
|
221 |
+
"故事",
|
222 |
+
"收拾",
|
223 |
+
"收成",
|
224 |
+
"提防",
|
225 |
+
"挖苦",
|
226 |
+
"挑剔",
|
227 |
+
"指甲",
|
228 |
+
"指头",
|
229 |
+
"拾掇",
|
230 |
+
"拳头",
|
231 |
+
"拨弄",
|
232 |
+
"招牌",
|
233 |
+
"招呼",
|
234 |
+
"抬举",
|
235 |
+
"护士",
|
236 |
+
"折腾",
|
237 |
+
"扫帚",
|
238 |
+
"打量",
|
239 |
+
"打算",
|
240 |
+
"打点",
|
241 |
+
"打扮",
|
242 |
+
"打听",
|
243 |
+
"打发",
|
244 |
+
"扎实",
|
245 |
+
"扁担",
|
246 |
+
"戒指",
|
247 |
+
"懒得",
|
248 |
+
"意识",
|
249 |
+
"意思",
|
250 |
+
"情形",
|
251 |
+
"悟性",
|
252 |
+
"怪物",
|
253 |
+
"思量",
|
254 |
+
"怎么",
|
255 |
+
"念头",
|
256 |
+
"念叨",
|
257 |
+
"快活",
|
258 |
+
"忙活",
|
259 |
+
"志气",
|
260 |
+
"心思",
|
261 |
+
"得罪",
|
262 |
+
"张罗",
|
263 |
+
"弟兄",
|
264 |
+
"开通",
|
265 |
+
"应酬",
|
266 |
+
"庄稼",
|
267 |
+
"干事",
|
268 |
+
"帮手",
|
269 |
+
"帐篷",
|
270 |
+
"希罕",
|
271 |
+
"师父",
|
272 |
+
"师傅",
|
273 |
+
"巴结",
|
274 |
+
"巴掌",
|
275 |
+
"差事",
|
276 |
+
"工夫",
|
277 |
+
"岁数",
|
278 |
+
"屁股",
|
279 |
+
"尾巴",
|
280 |
+
"少爷",
|
281 |
+
"小气",
|
282 |
+
"小伙",
|
283 |
+
"将就",
|
284 |
+
"对头",
|
285 |
+
"对付",
|
286 |
+
"寡妇",
|
287 |
+
"家伙",
|
288 |
+
"客气",
|
289 |
+
"实在",
|
290 |
+
"官司",
|
291 |
+
"学问",
|
292 |
+
"学生",
|
293 |
+
"字号",
|
294 |
+
"嫁妆",
|
295 |
+
"媳妇",
|
296 |
+
"媒人",
|
297 |
+
"婆家",
|
298 |
+
"娘家",
|
299 |
+
"委屈",
|
300 |
+
"姑娘",
|
301 |
+
"姐夫",
|
302 |
+
"妯娌",
|
303 |
+
"妥当",
|
304 |
+
"妖精",
|
305 |
+
"奴才",
|
306 |
+
"女婿",
|
307 |
+
"头发",
|
308 |
+
"太阳",
|
309 |
+
"大爷",
|
310 |
+
"大方",
|
311 |
+
"大意",
|
312 |
+
"大夫",
|
313 |
+
"多少",
|
314 |
+
"多么",
|
315 |
+
"外甥",
|
316 |
+
"壮实",
|
317 |
+
"地道",
|
318 |
+
"地方",
|
319 |
+
"在乎",
|
320 |
+
"困难",
|
321 |
+
"嘴巴",
|
322 |
+
"嘱咐",
|
323 |
+
"嘟囔",
|
324 |
+
"嘀咕",
|
325 |
+
"喜欢",
|
326 |
+
"喇嘛",
|
327 |
+
"喇叭",
|
328 |
+
"商量",
|
329 |
+
"唾沫",
|
330 |
+
"哑巴",
|
331 |
+
"哈欠",
|
332 |
+
"哆嗦",
|
333 |
+
"咳嗽",
|
334 |
+
"和尚",
|
335 |
+
"告诉",
|
336 |
+
"告示",
|
337 |
+
"含糊",
|
338 |
+
"吓唬",
|
339 |
+
"后头",
|
340 |
+
"名字",
|
341 |
+
"名堂",
|
342 |
+
"合同",
|
343 |
+
"吆喝",
|
344 |
+
"叫唤",
|
345 |
+
"口袋",
|
346 |
+
"厚道",
|
347 |
+
"厉害",
|
348 |
+
"千斤",
|
349 |
+
"包袱",
|
350 |
+
"包涵",
|
351 |
+
"匀称",
|
352 |
+
"勤快",
|
353 |
+
"动静",
|
354 |
+
"动弹",
|
355 |
+
"功夫",
|
356 |
+
"力气",
|
357 |
+
"前头",
|
358 |
+
"刺猬",
|
359 |
+
"刺激",
|
360 |
+
"别扭",
|
361 |
+
"利落",
|
362 |
+
"利索",
|
363 |
+
"利害",
|
364 |
+
"分析",
|
365 |
+
"出息",
|
366 |
+
"凑合",
|
367 |
+
"凉快",
|
368 |
+
"冷战",
|
369 |
+
"冤枉",
|
370 |
+
"冒失",
|
371 |
+
"养活",
|
372 |
+
"关系",
|
373 |
+
"先生",
|
374 |
+
"兄弟",
|
375 |
+
"便宜",
|
376 |
+
"使唤",
|
377 |
+
"佩服",
|
378 |
+
"作坊",
|
379 |
+
"体面",
|
380 |
+
"位置",
|
381 |
+
"似的",
|
382 |
+
"伙计",
|
383 |
+
"休息",
|
384 |
+
"什么",
|
385 |
+
"人家",
|
386 |
+
"亲戚",
|
387 |
+
"亲家",
|
388 |
+
"交情",
|
389 |
+
"云彩",
|
390 |
+
"事情",
|
391 |
+
"买卖",
|
392 |
+
"主意",
|
393 |
+
"丫头",
|
394 |
+
"丧气",
|
395 |
+
"两口",
|
396 |
+
"东西",
|
397 |
+
"东家",
|
398 |
+
"世故",
|
399 |
+
"不由",
|
400 |
+
"不在",
|
401 |
+
"下水",
|
402 |
+
"下巴",
|
403 |
+
"上头",
|
404 |
+
"上司",
|
405 |
+
"丈夫",
|
406 |
+
"丈人",
|
407 |
+
"一辈",
|
408 |
+
"那个",
|
409 |
+
"菩萨",
|
410 |
+
"父亲",
|
411 |
+
"母亲",
|
412 |
+
"咕噜",
|
413 |
+
"邋遢",
|
414 |
+
"费用",
|
415 |
+
"冤家",
|
416 |
+
"甜头",
|
417 |
+
"介绍",
|
418 |
+
"荒唐",
|
419 |
+
"大人",
|
420 |
+
"泥鳅",
|
421 |
+
"幸福",
|
422 |
+
"熟悉",
|
423 |
+
"计划",
|
424 |
+
"扑腾",
|
425 |
+
"蜡烛",
|
426 |
+
"姥爷",
|
427 |
+
"照顾",
|
428 |
+
"喉咙",
|
429 |
+
"吉他",
|
430 |
+
"弄堂",
|
431 |
+
"蚂蚱",
|
432 |
+
"凤凰",
|
433 |
+
"拖沓",
|
434 |
+
"寒碜",
|
435 |
+
"糟蹋",
|
436 |
+
"倒腾",
|
437 |
+
"报复",
|
438 |
+
"逻辑",
|
439 |
+
"盘缠",
|
440 |
+
"喽啰",
|
441 |
+
"牢骚",
|
442 |
+
"咖喱",
|
443 |
+
"扫把",
|
444 |
+
"惦记",
|
445 |
+
}
|
446 |
+
self.must_not_neural_tone_words = {
|
447 |
+
"男子",
|
448 |
+
"女子",
|
449 |
+
"分子",
|
450 |
+
"原子",
|
451 |
+
"量子",
|
452 |
+
"莲子",
|
453 |
+
"石子",
|
454 |
+
"瓜子",
|
455 |
+
"电子",
|
456 |
+
"人人",
|
457 |
+
"虎虎",
|
458 |
+
"幺幺",
|
459 |
+
"干嘛",
|
460 |
+
"学子",
|
461 |
+
"哈哈",
|
462 |
+
"数数",
|
463 |
+
"袅袅",
|
464 |
+
"局地",
|
465 |
+
"以下",
|
466 |
+
"娃哈哈",
|
467 |
+
"花花草草",
|
468 |
+
"留得",
|
469 |
+
"耕地",
|
470 |
+
"想想",
|
471 |
+
"熙熙",
|
472 |
+
"攘攘",
|
473 |
+
"卵子",
|
474 |
+
"死死",
|
475 |
+
"冉冉",
|
476 |
+
"恳恳",
|
477 |
+
"佼佼",
|
478 |
+
"吵吵",
|
479 |
+
"打打",
|
480 |
+
"考考",
|
481 |
+
"整整",
|
482 |
+
"莘莘",
|
483 |
+
"落地",
|
484 |
+
"算子",
|
485 |
+
"家家户户",
|
486 |
+
"青青",
|
487 |
+
}
|
488 |
+
self.punc = ":,;。?!“”‘’':,;.?!"
|
489 |
+
|
490 |
+
# the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041
|
491 |
+
# e.g.
|
492 |
+
# word: "家里"
|
493 |
+
# pos: "s"
|
494 |
+
# finals: ['ia1', 'i3']
|
495 |
+
def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]:
|
496 |
+
# reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺
|
497 |
+
for j, item in enumerate(word):
|
498 |
+
if (
|
499 |
+
j - 1 >= 0
|
500 |
+
and item == word[j - 1]
|
501 |
+
and pos[0] in {"n", "v", "a"}
|
502 |
+
and word not in self.must_not_neural_tone_words
|
503 |
+
):
|
504 |
+
finals[j] = finals[j][:-1] + "5"
|
505 |
+
ge_idx = word.find("个")
|
506 |
+
if len(word) >= 1 and word[-1] in "吧呢哈啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶":
|
507 |
+
finals[-1] = finals[-1][:-1] + "5"
|
508 |
+
elif len(word) >= 1 and word[-1] in "的地得":
|
509 |
+
finals[-1] = finals[-1][:-1] + "5"
|
510 |
+
# e.g. 走了, 看着, 去过
|
511 |
+
elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
|
512 |
+
finals[-1] = finals[-1][:-1] + "5"
|
513 |
+
elif (
|
514 |
+
len(word) > 1
|
515 |
+
and word[-1] in "们子"
|
516 |
+
and pos in {"r", "n"}
|
517 |
+
and word not in self.must_not_neural_tone_words
|
518 |
+
):
|
519 |
+
finals[-1] = finals[-1][:-1] + "5"
|
520 |
+
# e.g. 桌上, 地下, 家里
|
521 |
+
elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
|
522 |
+
finals[-1] = finals[-1][:-1] + "5"
|
523 |
+
# e.g. 上来, 下去
|
524 |
+
elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开":
|
525 |
+
finals[-1] = finals[-1][:-1] + "5"
|
526 |
+
# "个" as a measure word takes the neutral tone
|
527 |
+
elif (
|
528 |
+
ge_idx >= 1
|
529 |
+
and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
|
530 |
+
) or word == "个":
|
531 |
+
finals[ge_idx] = finals[ge_idx][:-1] + "5"
|
532 |
+
else:
|
533 |
+
if (
|
534 |
+
word in self.must_neural_tone_words
|
535 |
+
or word[-2:] in self.must_neural_tone_words
|
536 |
+
):
|
537 |
+
finals[-1] = finals[-1][:-1] + "5"
|
538 |
+
|
539 |
+
word_list = self._split_word(word)
|
540 |
+
finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
|
541 |
+
for i, word in enumerate(word_list):
|
542 |
+
# conventional neutral-tone words in Chinese
|
543 |
+
if (
|
544 |
+
word in self.must_neural_tone_words
|
545 |
+
or word[-2:] in self.must_neural_tone_words
|
546 |
+
):
|
547 |
+
finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
|
548 |
+
finals = sum(finals_list, [])
|
549 |
+
return finals
|
550 |
+
|
551 |
+
def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]:
|
552 |
+
# e.g. 看不懂
|
553 |
+
if len(word) == 3 and word[1] == "不":
|
554 |
+
finals[1] = finals[1][:-1] + "5"
|
555 |
+
else:
|
556 |
+
for i, char in enumerate(word):
|
557 |
+
# "不" before tone4 should be bu2, e.g. 不怕
|
558 |
+
if char == "不" and i + 1 < len(word) and finals[i + 1][-1] == "4":
|
559 |
+
finals[i] = finals[i][:-1] + "2"
|
560 |
+
return finals
|
561 |
+
|
562 |
+
def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
|
563 |
+
# "一" in number sequences, e.g. 一零零, 二一零
|
564 |
+
if word.find("一") != -1 and all(
|
565 |
+
[item.isnumeric() for item in word if item != "一"]
|
566 |
+
):
|
567 |
+
return finals
|
568 |
+
# "一" between reduplication words shold be yi5, e.g. 看一看
|
569 |
+
elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
|
570 |
+
finals[1] = finals[1][:-1] + "5"
|
571 |
+
# when "一" is ordinal word, it should be yi1
|
572 |
+
elif word.startswith("第一"):
|
573 |
+
finals[1] = finals[1][:-1] + "1"
|
574 |
+
else:
|
575 |
+
for i, char in enumerate(word):
|
576 |
+
if char == "一" and i + 1 < len(word):
|
577 |
+
# "一" before tone4 should be yi2, e.g. 一段
|
578 |
+
if finals[i + 1][-1] == "4":
|
579 |
+
finals[i] = finals[i][:-1] + "2"
|
580 |
+
# "一" before non-tone4 should be yi4, e.g. 一天
|
581 |
+
else:
|
582 |
+
# "一" 后面如果是标点,还读一声
|
583 |
+
if word[i + 1] not in self.punc:
|
584 |
+
finals[i] = finals[i][:-1] + "4"
|
585 |
+
return finals
|
586 |
+
|
587 |
+
def _split_word(self, word: str) -> List[str]:
|
588 |
+
word_list = jieba.cut_for_search(word)
|
589 |
+
word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
|
590 |
+
first_subword = word_list[0]
|
591 |
+
first_begin_idx = word.find(first_subword)
|
592 |
+
if first_begin_idx == 0:
|
593 |
+
second_subword = word[len(first_subword) :]
|
594 |
+
new_word_list = [first_subword, second_subword]
|
595 |
+
else:
|
596 |
+
second_subword = word[: -len(first_subword)]
|
597 |
+
new_word_list = [second_subword, first_subword]
|
598 |
+
return new_word_list
|
599 |
+
|
600 |
+
def _three_sandhi(self, word: str, finals: List[str]) -> List[str]:
|
601 |
+
if len(word) == 2 and self._all_tone_three(finals):
|
602 |
+
finals[0] = finals[0][:-1] + "2"
|
603 |
+
elif len(word) == 3:
|
604 |
+
word_list = self._split_word(word)
|
605 |
+
if self._all_tone_three(finals):
|
606 |
+
# disyllabic + monosyllabic, e.g. 蒙古/包
|
607 |
+
if len(word_list[0]) == 2:
|
608 |
+
finals[0] = finals[0][:-1] + "2"
|
609 |
+
finals[1] = finals[1][:-1] + "2"
|
610 |
+
# monosyllabic + disyllabic, e.g. 纸/老虎
|
611 |
+
elif len(word_list[0]) == 1:
|
612 |
+
finals[1] = finals[1][:-1] + "2"
|
613 |
+
else:
|
614 |
+
finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
|
615 |
+
if len(finals_list) == 2:
|
616 |
+
for i, sub in enumerate(finals_list):
|
617 |
+
# e.g. 所有/人
|
618 |
+
if self._all_tone_three(sub) and len(sub) == 2:
|
619 |
+
finals_list[i][0] = finals_list[i][0][:-1] + "2"
|
620 |
+
# e.g. 好/喜欢
|
621 |
+
elif (
|
622 |
+
i == 1
|
623 |
+
and not self._all_tone_three(sub)
|
624 |
+
and finals_list[i][0][-1] == "3"
|
625 |
+
and finals_list[0][-1][-1] == "3"
|
626 |
+
):
|
627 |
+
finals_list[0][-1] = finals_list[0][-1][:-1] + "2"
|
628 |
+
finals = sum(finals_list, [])
|
629 |
+
# split an idiom into two words whose length is 2
|
630 |
+
elif len(word) == 4:
|
631 |
+
finals_list = [finals[:2], finals[2:]]
|
632 |
+
finals = []
|
633 |
+
for sub in finals_list:
|
634 |
+
if self._all_tone_three(sub):
|
635 |
+
sub[0] = sub[0][:-1] + "2"
|
636 |
+
finals += sub
|
637 |
+
|
638 |
+
return finals
|
639 |
+
|
640 |
+
def _all_tone_three(self, finals: List[str]) -> bool:
|
641 |
+
return all(x[-1] == "3" for x in finals)
|
642 |
+
|
643 |
+
# merge "不" and the word behind it
|
644 |
+
# if don't merge, "不" sometimes appears alone according to jieba, which may occur sandhi error
|
645 |
+
def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
646 |
+
new_seg = []
|
647 |
+
last_word = ""
|
648 |
+
for word, pos in seg:
|
649 |
+
if last_word == "不":
|
650 |
+
word = last_word + word
|
651 |
+
if word != "不":
|
652 |
+
new_seg.append((word, pos))
|
653 |
+
last_word = word[:]
|
654 |
+
if last_word == "不":
|
655 |
+
new_seg.append((last_word, "d"))
|
656 |
+
last_word = ""
|
657 |
+
return new_seg
|
658 |
+
|
659 |
+
# function 1: merge "一" and reduplication words in it's left and right, e.g. "听","一","听" ->"听一听"
|
660 |
+
# function 2: merge a single "一" with the word that follows it
|
661 |
+
# if don't merge, "一" sometimes appears alone according to jieba, which may occur sandhi error
|
662 |
+
# e.g.
|
663 |
+
# input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')]
|
664 |
+
# output seg: [['听一听', 'v']]
|
665 |
+
def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
666 |
+
new_seg = []
|
667 |
+
# function 1
|
668 |
+
for i, (word, pos) in enumerate(seg):
|
669 |
+
if (
|
670 |
+
i - 1 >= 0
|
671 |
+
and word == "一"
|
672 |
+
and i + 1 < len(seg)
|
673 |
+
and seg[i - 1][0] == seg[i + 1][0]
|
674 |
+
and seg[i - 1][1] == "v"
|
675 |
+
and seg[i + 1][1] == "v"
|
676 |
+
):
|
677 |
+
new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0]
|
678 |
+
else:
|
679 |
+
if (
|
680 |
+
i - 2 >= 0
|
681 |
+
and seg[i - 1][0] == "一"
|
682 |
+
and seg[i - 2][0] == word
|
683 |
+
and pos == "v"
|
684 |
+
):
|
685 |
+
continue
|
686 |
+
else:
|
687 |
+
new_seg.append([word, pos])
|
688 |
+
seg = new_seg
|
689 |
+
new_seg = []
|
690 |
+
# function 2
|
691 |
+
for i, (word, pos) in enumerate(seg):
|
692 |
+
if new_seg and new_seg[-1][0] == "一":
|
693 |
+
new_seg[-1][0] = new_seg[-1][0] + word
|
694 |
+
else:
|
695 |
+
new_seg.append([word, pos])
|
696 |
+
return new_seg
|
697 |
+
|
698 |
+
# the first and the second words are all_tone_three
|
699 |
+
def _merge_continuous_three_tones(
|
700 |
+
self, seg: List[Tuple[str, str]]
|
701 |
+
) -> List[Tuple[str, str]]:
|
702 |
+
new_seg = []
|
703 |
+
sub_finals_list = [
|
704 |
+
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
|
705 |
+
for (word, pos) in seg
|
706 |
+
]
|
707 |
+
assert len(sub_finals_list) == len(seg)
|
708 |
+
merge_last = [False] * len(seg)
|
709 |
+
for i, (word, pos) in enumerate(seg):
|
710 |
+
if (
|
711 |
+
i - 1 >= 0
|
712 |
+
and self._all_tone_three(sub_finals_list[i - 1])
|
713 |
+
and self._all_tone_three(sub_finals_list[i])
|
714 |
+
and not merge_last[i - 1]
|
715 |
+
):
|
716 |
+
# if the last word is a reduplication, do not merge, because reduplication needs _neural_sandhi
|
717 |
+
if (
|
718 |
+
not self._is_reduplication(seg[i - 1][0])
|
719 |
+
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
|
720 |
+
):
|
721 |
+
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
|
722 |
+
merge_last[i] = True
|
723 |
+
else:
|
724 |
+
new_seg.append([word, pos])
|
725 |
+
else:
|
726 |
+
new_seg.append([word, pos])
|
727 |
+
|
728 |
+
return new_seg
|
729 |
+
|
730 |
+
def _is_reduplication(self, word: str) -> bool:
|
731 |
+
return len(word) == 2 and word[0] == word[1]
|
732 |
+
|
733 |
+
# the last char of first word and the first char of second word is tone_three
|
734 |
+
def _merge_continuous_three_tones_2(
|
735 |
+
self, seg: List[Tuple[str, str]]
|
736 |
+
) -> List[Tuple[str, str]]:
|
737 |
+
new_seg = []
|
738 |
+
sub_finals_list = [
|
739 |
+
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
|
740 |
+
for (word, pos) in seg
|
741 |
+
]
|
742 |
+
assert len(sub_finals_list) == len(seg)
|
743 |
+
merge_last = [False] * len(seg)
|
744 |
+
for i, (word, pos) in enumerate(seg):
|
745 |
+
if (
|
746 |
+
i - 1 >= 0
|
747 |
+
and sub_finals_list[i - 1][-1][-1] == "3"
|
748 |
+
and sub_finals_list[i][0][-1] == "3"
|
749 |
+
and not merge_last[i - 1]
|
750 |
+
):
|
751 |
+
# if the last word is a reduplication, do not merge, because reduplication needs _neural_sandhi
|
752 |
+
if (
|
753 |
+
not self._is_reduplication(seg[i - 1][0])
|
754 |
+
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
|
755 |
+
):
|
756 |
+
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
|
757 |
+
merge_last[i] = True
|
758 |
+
else:
|
759 |
+
new_seg.append([word, pos])
|
760 |
+
else:
|
761 |
+
new_seg.append([word, pos])
|
762 |
+
return new_seg
|
763 |
+
|
764 |
+
def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
765 |
+
new_seg = []
|
766 |
+
for i, (word, pos) in enumerate(seg):
|
767 |
+
if i - 1 >= 0 and word == "儿" and seg[i - 1][0] != "#":
|
768 |
+
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
|
769 |
+
else:
|
770 |
+
new_seg.append([word, pos])
|
771 |
+
return new_seg
|
772 |
+
|
773 |
+
def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
774 |
+
new_seg = []
|
775 |
+
for i, (word, pos) in enumerate(seg):
|
776 |
+
if new_seg and word == new_seg[-1][0]:
|
777 |
+
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
|
778 |
+
else:
|
779 |
+
new_seg.append([word, pos])
|
780 |
+
return new_seg
|
781 |
+
|
782 |
+
def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
783 |
+
seg = self._merge_bu(seg)
|
784 |
+
try:
|
785 |
+
seg = self._merge_yi(seg)
|
786 |
+
except Exception:
|
787 |
+
print("_merge_yi failed")
|
788 |
+
seg = self._merge_reduplication(seg)
|
789 |
+
try:
|
790 |
+
seg = self._merge_continuous_three_tones(seg)
|
791 |
+
except Exception:
|
792 |
+
print("_merge_continuous_three_tones failed")
|
793 |
+
try:
|
794 |
+
seg = self._merge_continuous_three_tones_2(seg)
|
795 |
+
except Exception:
|
796 |
+
print("_merge_continuous_three_tones_2 failed")
|
797 |
+
|
798 |
+
seg = self._merge_er(seg)
|
799 |
+
return seg
|
800 |
+
|
801 |
+
def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]:
|
802 |
+
finals = self._bu_sandhi(word, finals)
|
803 |
+
finals = self._yi_sandhi(word, finals)
|
804 |
+
finals = self._neural_sandhi(word, pos, finals)
|
805 |
+
finals = self._three_sandhi(word, finals)
|
806 |
+
return finals
|
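A minimal usage sketch for the class above (assumed from how the companion chinese.py in this package would typically drive it; not part of this file): segment with POS tags, pre-merge, then run `modified_tone` per word:
import jieba_fast.posseg as psg  # assumes jieba_fast exposes posseg like jieba does
from pypinyin import lazy_pinyin, Style

sandhier = ToneSandhi()
seg = [(p.word, p.flag) for p in psg.cut("我们一起看一看豆腐")]
seg = sandhier.pre_merge_for_modify(seg)
for word, pos in seg:
    finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
    print(word, sandhier.modified_tone(word, pos, finals))
    # e.g. 豆腐 -> ['ou4', 'u5']: the second syllable is turned neutral (tone 5)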
SongBloom/g2p/cn_zh_g2p/zh_normalization/README.md
ADDED
@@ -0,0 +1,16 @@
1 |
+
## Supported NSW (Non-Standard-Word) Normalization
|
2 |
+
|
3 |
+
|NSW type|raw|normalized|
|
4 |
+
|:--|:-|:-|
|
5 |
+
|serial number|电影中梁朝伟扮演的陈永仁的编号27149|电影中梁朝伟扮演的陈永仁的编号二七一四九|
|
6 |
+
|cardinal|这块黄金重达324.75克<br>我们班的最高总分为583分|这块黄金重达三百二十四点七五克<br>我们班的最高总分为五百八十三分|
|
7 |
+
|numeric range |12\~23<br>-1.5\~2|十二到二十三<br>负一点五到二|
|
8 |
+
|date|她出生于86年8月18日,她弟弟出生于1995年3月1日|她出生于八六年八月十八日, 她弟弟出生于一九九五年三月一日|
|
9 |
+
|time|等会请在12:05请通知我|等会请在十二点零五分请通知我|
|
10 |
+
|temperature|今天的最低气温达到-10°C|今天的最低气温达到零下十度|
|
11 |
+
|fraction|现场有7/12的观众投出了赞成票|现场有十二分之七的观众投出了赞成票|
|
12 |
+
|percentage|明天有62%的概率降雨|明天有百分之六十二的概率降雨|
|
13 |
+
|money|随便来几个价格12块5,34.5元,20.1万|随便来几个价格十二块五,三十四点五元,二十点一万|
|
14 |
+
|telephone|这是固话0421-33441122<br>这是手机+86 18544139121|这是固话零四二一三三四四一一二二<br>这是手机八六一八五四四一三九一二一|
|
15 |
+
## References
|
16 |
+
[Pull requests #658 of DeepSpeech](https://github.com/PaddlePaddle/DeepSpeech/pull/658/files)
|
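A minimal usage sketch (an assumption carried over from the PaddleSpeech code this module derives from; the `TextNormalizer` class lives in text_normlization.py, which is not shown in this excerpt, and its normalize() is assumed to return a list of normalized sentences):
from SongBloom.g2p.cn_zh_g2p.zh_normalization import TextNormalizer  # assumed export

tn = TextNormalizer()
sentences = tn.normalize("今天的最低气温达到-10°C,降雨概率62%。")
# per the table above: -10°C -> 零下十度, 62% -> 百分之六十二
print(sentences)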
SongBloom/g2p/cn_zh_g2p/zh_normalization/__init__.py
ADDED
@@ -0,0 +1,14 @@
1 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from .text_normlization import *
|
SongBloom/g2p/cn_zh_g2p/zh_normalization/char_convert.py
ADDED
@@ -0,0 +1,46 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Traditional and simplified Chinese conversion, a simplified character may correspond to multiple traditional characters.
|
16 |
+
"""
|
17 |
+
simplified_charcters = '制咖片型超声盘鉴定仔点他命书歌粉巾字帐恤手指记忆棒形转弯沟光○〇㐄㐅㐆㐌㐖毒㐜㐡㐤㐰㐺㑇㑳㒳㒸㔾㗂㗎㝵㞎㞙㞞以㢲㢴㤅㥁㥯㨗㫺㬎㮎㮚㮸㲋㲱㲾㳮涧㵪㶸㷖㷭㹢㹴犬㺢狓㺵碗㽮㿝䍃䔢䖟䖸䗈䗥䗪䝓射䥯䦉䯝鲃鱼䲔䳗鹅䵹鼄䶑一对应映射丁不识下儿子做二休世丘之貉并中台原则串为甚谓干净了百事无成八变五十些人得道鸡升天代如并来去个国政策劲幽灵在欧洲游荡接样萝卜坑侧化传价元论醇共再准刀两断切分耕耘收获钱货物向看旧就绪险刻千金动劳永逸匙零夜半卡通回复返影踪反常态口咬气句话同吐快吹周味呼诺呜品红锅哄而散起唱和问三知生熟团漆黑火糟堆场空块面塌糊涂尘染壁厢夔已足多情露水大早到晚夫妻当关万莫开失古恨套所料既往孔见提师要家主审寸阴难买斗牛小撮部阵局展身层巴掌帆风顺席地带过年计于春头载四季期被蛇怕井绳度愿式份弹顷深前律径心意念差愁孤行俱全房厅交遮打技长把抓死拿眼泪鼻涕钥锁折段抿拍即合扫排掬挥拨拥上入击洞掷揽改故辙败文值名斑方面旁族日秋餐隔雅里终父旦时晌会霎间晃暴寒曝更月望垠际朝夕本正经利杯羹东西板枝独秀根筋杆进条龙服务概模次函数又性程总付步脚印趋登毛拔呵氧氮碳决雌雄波未平派谎言流清楚白准溜烟潭有获闻是处降琴鹤甲病发可拾沙目然了直以相眨穿睹瞥瞬矢的解石鸟神教秉虔诚秘种窝蜂穷窍笑置笔苟勾销抹杀煞等奖箍节吃箭仇双雕诗筹箩筐系列纸级士官统丝毫挂维网尽线微吭响股脑胎脉承腔臂力致效资源址器举功投般说讲规贸易叶障着慎满皆输号木电池衣倾钟高低视仁觉醒览遗角银币触溃九鼎蔽抄出驷马追重语破贫洗贯走路安蹴至几蹶振跃役胆汗较辈轮辞赞退六连遍递边针血锤音错门思闪真倒项栽雾类保护川先惊乍体哄鳞爪鸣滴泡邻域党专鼓作齐炒丑烯亥克内酯冬加奴卯肝炎基尺梁街裤镐客宠庭巳汝昌烷玲磊糖肇酉醛啷青县韪良香骨鲷丂七集河市弦喜嘴张舌堵区工业姊妹星架构巧彩扭歪拼凑余热曜武州爷浮屠美乡老阶树荤素碎落能魄鳃鳗珠丄丅丆万俟丈尚摸母娘量管群亚虎必我堂令申件装伏位博侠义界表女墟台戏臭皮匠胜诸葛亮赛顶倍催请运算包立叉戟离疫苗土史志演围揭瓦晒夷姑婆帝村宝烂尖杉碱屉桌山岔岛由纪峡坝库镇废从德后拗汤治旬食明昧曹朋友框栏极权幂曲归依猫民氟硼氯磷铁江侗自旅法司洋浦梅园温暖湾焦班幸用田略番叠皇炮捶硝苯酸腺苷棱草镜穗跳远索锦纲聚氰胺联店胚膲爱色堇紫罗兰芝茶饭菱云虫藏藩乱叛苏亲债凳学座恐恋柱测肌腹衩锥系貂企乌跪叩军车农题迭都甘油屯奏键短阿姨陪姐只顾茅庐槽驾魂鲜鹿页其菜单乘任供势午齿汉组织吊调泻唇坡城报坟外夸将尉建筑岸岗公床扬新剑升杭林栗校楼标款汽社浣海商馆剧院钢华港机械广媒环球融第医科证券综财乐育游涨犹岭疏瘾睑确兵领导缴肢膛船艾瑟尔苍蔡虞效衫覆访诉课谕议轨述野钩限敌鞋颌颔颚饶首龈站例修凡划垂届属崽颏厨拜挫摆放旋削棋榻槛礼沉注滑营狱画确仪聘花葬诏员跌辖周达酒锚闸陷陆雨雪飞威丌于丹久乏予理评产亢卑亦乎舞己悲矩圆词害志但住佞佳便俗信票案幅翁倦伦假偏倚斜亏鬼敲停备伤脾胃仅此像俭匮免宜穴焉戴兼容许冻伯仲负彼昼皂轩轾实刊划颠卫战哥比省非好黄饰别拘束掩奶睬选择摇扰烦苦枚写协厌及格受欢迎约只估侵犯割状告或缺抗拒挽撤救药喻磨灭端倪少逆逾越避靠适吉誉吝玉含延咎歹听啻渊善谋均匀堪忍够太惹妙妥妨孕症孝术室完纳推冠积宣疑辩栗碴称屈挠屑干涉衡待很忙恶忿怎么怠急耻恭息悦惑惜惟想愉愧怍慌愤启懂懈怀材才紧招认扣抵拉舍也罢插揣冒搭撞南墙扩核支攻敢雷攀敬里吗需景智暇曾罪遇朽枉止况竞争辱求愈渝溶济左右袒困补爽特寂寞示弱找谢畏强疾徐痛痒冤符眠睦瞅董何厚云措活疲羞者轻玻璃祥兆禁���稂莠稳佛换答简结果盟绝缕途给谈否羁翼耐肖胫毋宁兴舒若菲莱痕迹窠臼虚衰脸兔撒鹰棺范该详讳抬泰让须眉象众赀账费灰赖奇虑训辍辨菽麦辛近送透逞徒速续逮捕遂遑违逊斧钺艰醉锈随观弃显饱脂肪使丏丐帮丒且慢末丕替桃宗王尊凉爵各图屋脊粮署录坛吾禄职胄袭君厦丗北壑桐疹损逢陵鹬丙寅戌氨腈唑纶辰酮脱氢酶醚丞丢现掉纱帽弄扯炮碗丠両丣坐存激肩臻蒂莲悖序驱丨丩丫挺杈髻鬟细介俄伊犁京尼布订普渡央委监察检查剂圈设警队斯督剩震境航舶革防托播促质版蝾螈锋研艺历残消频谱精密制造陲邮候埔坚压坜凹汇执府究邦俘摄寮彬狼岳肺肿庸英讯诊埋粒胞括控码韩暑枪枢砥澳哇牟寿甸钻探篇签缀缝继耳肯照妇埃悬璧轴柜台辣搁浅邪跑纤阮阳私囊魔丮丰姿采丱烧丳丵丶丷丸参寨朗桂瑞砂衷霞貌凤仆舰因嫌宰峰干络牌持旨祭祷簿编罚宾办丼丿乀乂乃乄仰慕盛旷留考验阔乆乇么丑麽乊湖燃乑乒乓乕乖僻忤戾离谬迕乗危肥劫除隙浪婿乙炔肠酰吡咯盐乚乛乜嘢卿玄宫尾狐龟塔嶷兄弟泉章霄钉耙乞扎哀怜恕讨乢乣乤乥乧乨乩童乪乫乭乳晕汁液瑶浆牙癌突窦罩腐胶猪酪蛋糕菌瘤乴乵乶乷乸乹乺乼乾俸冰嘉哕嚎坤妈尸垒旱枯涸俐渴潮涩煸豆燥爹瘦瘪癣瞪袋脆姜贝隆馏乿亀亁叫咕攘扔搞男砸窜蓬麻亃亄亅却亇迟典今临繁累卵奉婚聪躬巨与迁添裂副宿岁怪恶尕仑愣杆硅硫钛铀锰芑杂异钠砷胂磺琥珀舱棍簧胡茬盗浩盆贩郎腿亍洪亐互欠助勉惠操斥诿系户译亓墓碑刑铃卅渠缤纷斗米旗宪钒灯徽瘟祖拳福谷丰脏腑绑肉腌苓蕴桥铺霸颜闹判喷冈底蛙陉矿亖亘亜罕们娜桑那努哈喀弗烈曼松森杜氏杯奥琛敦戊穆圣裔汇薛孙亟亡佚虏羊牢奋释卷卸契媾感额睫缠谊趾塞挤纽阻还配驰庄亨洛祚亪享津沪畿郊慈菴枇杷膏亭阁锃丽亳亶亹诛初责翻疯偶杰丛稠妖拖寰居吸授慧蜗吞壮魅狗矛盾益渣患忧稀描猿梦暂涯畜祸缘沸搜引擎臣横纭谁混援蒸兽狮税剖亻亼亽亡什献刹邡么仂仃仄仆富怨仈仉毕昔晨壳绍仍仏仒仕宦仗欺恃腰叹叹炬梓讫施仙后琼逝仚仝仞仟悔仡佬偿填泊拓扑簇羔购顿钦佩发棻阃驭养亿儆尤借帧赈凌叙帖李柔刚沃眦睚戒讹取飨读仨仫仮著泳卧躺韶夏裁仳仵唯贤凭钓诞仿似宋佛讽伀硕盼鹅伄儅伈伉俪柯始娃迈戈坦堡帕茨萨庙玛莉莎藤霍姆伋伍奢胥廷芳豪伎俩侍汛勒希羲雏伐憩整谟闲闲伕伙伴颐伜伝伢叔恒兹恩翰伱伲侣伶俜悧鼬伸懒缩喇叭伹伺伻伽倻辐伾似佃伫布乔妮墨佉卢佌贷劣廉昂档浓矮伞洼缓耗胸谷迷挡率龋宅沫舍疗佐贰佑占优据铧尝呢须鲁晓佗佘余坪寺瓜铳僧蒙芒陀龛哼呕坊奸孽弊揖祟茧缚誓贼佝偻瞀佟你夺赶佡佢佣佤佧贾佪佫佯佰佱洁绩酿肴佴卷佶佷佸佹佺佻佼佽佾具唤窘坏娱怒慨硬习惯聋膨胀蔓骇贵痹侀侁侂侃侄侅鸿燕侇侈糜靡侉侌妾侏儒仓鼠侐侑侔仑侘侚链侜偎傍钴循柳葫芦附価侮骂蔑侯岩截蚀局贴壶嬛宴捷携桶笺酌俣狭膝狄俅俉俊俏俎俑俓俔谚俚俛黎健呈固墒增守康箱湿祐镖镳杠盒靖膜龄俞豹猎噪孚封札筒托衍鸽剪撰稿炼厂禊练缮葺俯瞰撑冲效俳俴俵俶俷俺备俾伥倂倅储卒惶敷猝逃颉蓄崇隐倌倏忽刺蜡烛噍嚼坍扁抽毙葱楣灌灶粪背薮卖赔闭霉腾倓倔幸倘倜傥倝借箸挹浇阅倡狂倢倣値倥偬倨傲倩匡嗣冲柝珍倬倭寇猩倮倶倷倹勤赞偁偃充伪吏嗓寐惺扮拱芫茜藉虢钞偈伟晶偌宕距析滤殿疼瘫注颇偓偕鸭歇滞偝偟偢忘怡旺偨偩逼偫偭偯偰偱偲侦缉蹄偷减惰漏窥窃偸偺迹傀儡傅傈僳骂篱傎奎琳迪叟芭傒傔傕伧悉荒傜傞傢傣芽逼佣婢傮睨寄檄诵谣颂伛担辜弓惨蒿悼疤傺傻屄臆巢泄箧羡盖轧颓傿㑩僄僇佥僊働僎侨僔僖僚僝伪僣僤侥僦猴偾僩僬僭僮僯僰雇僵殖签静僾僿征陇儁侬儃儇侩朴薄儊儋儌儍傧儓俦侪拟尽儜儞儤儦儩汰哉寡渥裕酷儭儱罐儳儵儹傩俨儽兀臬臲鹫允勋勋宙宵帅憝彝谐嫂阋畅沛溢盈饥赫凶悍狠猛顽愚妣斩秦遣鞭耀敏荣槃泽爆碟磁秃缆辉霁卤朵娄孜烽酱勃汀箕裘钳耶蒙蕾彻兑软遭黜兎児韵媳爸兕觥兖兙兛兜售鍪肚兝兞兟兡兢兣樽殓涅睡禀籍赘泌啡肽奸幕涵涝熵疚眷稃衬讧赴焕椒歼植跏没试误猜栖窗肋袖颊兪卦撇胡岐廓轿疸枫茴珑厕秩募勺吨寓斤历亩迫筷厘最淫螺韬兮宽匪筛襄赢轭复兲诈刃堰戎痞蚁饷它冀铸冂冃円冇冉册嫁厉砺竭醮冏牧冑冓冔冕冖冗冘冞冢窄抑诬冥冫烘菇蛰冷凝坨橇淇淋炭饼砖碛窖醋雕雹霜冱冶炉艳嘲峻滩淡漠煖飕饮冼冽凃凄怆梗凅凇净凊凋敝蒙凔凛遵汞脢凞几凢処凰凯凵凶焰凸折刷纹预丧喽奔巡榜殡芙蓉租笼辑鞘萃凼锯镬刁蛮刂娩崩批拆摊掰蘖骤歧颗秒袂赃勿嘱忌磋琢肤刈羽刎讼戮舂桨艇刓刖霹雳刜创犊刡恙墅帜筵致劫劫刨昏默攸尿欲熏润薰圭删刮痧铲刱刲刳刴刵踏磅戳柏槐绣芹苋猬舟铭鹄鹜劫剁剃辫刭锉履铅克剌姻咽哨廊掠桅沿召瞻翅赵卜渺茫郭剒剔剕沥剚愎毅讷才剜剥啄采剞剟剡剣剤䌽剐肾驶黏剰袍剀紊铲剸剺剽剿劁劂札劈啪柴扳啦刘奭姥夼昫涓熙禅禹锡翔雁鹗刽刿弩柄蜻蛉劒劓劖劘劙澜篑赏矶釜晋甜薪逐劦熔纣虐赤囚劬劭労劵效劻劼劾峭艮勅勇励勍勐腊脖庞漫饲荡粥辄勖勗勘骄馁碌泮雇捐竹骑殊阱绩朴恳谨剿勧勩勯勰劢勋勷劝惩慰诫谏勹芡践阑匁庇拯粟扎袱裹饺匆遽匈匉匊匋匍匐茎匏匕妆痰脓蛹斋苑烤蹈塘羌熊阀螳螂疆碚竿纬荷茵邙魏匚匜匝匟扶稷匣匦拢匸匹耦匽匾匿卂叮疮禧轸堤棚迢钧炼卄卆遐卉瓷盲瓶当胱腱裸卋卌卍卐怯污贱鄙龌龊陋卓溪唐梯渔陈枣泥漳浔涧梨芬谯赡辕迦郑単驴弈洽鳌卛占筮卝卞卟吩啉屎翠厄卣卨卪卬卮榫袄玺绶钮蚤惧殆笃耸卲帘帙绕恤卼卽厂厎厓厔厖厗奚厘厍厜厝谅厕厤厥厪腻孢厮厰厳厣厹厺粕垢芜菁厼厾叁悟茸薯叄吵笄悌哺讥坫
垄弧芯杠潜婴刍袁诘贪谍煽馈驳収岳缔灾贿骗叚叡吻拦蘑蜜诀燧玩砚筝椎蔺铜逗骊另觅叨唠谒杵姓喊嚷嚣咚咛塑寻恼憎擦只泣渗蝠叱吒咄咤喝籀黛舵舷叵叶铎懿昭穰苴辽叻叼吁堑嫖赌瞧爬众抒吅吆夥卺橡涤抱纵摩郡唁坠扇篮膀袜颈吋忾谘酬哭妓媛暗表缰迩妃羿絮蕃浑拐葵暮隅吔吖啶嗪戚吜啬噬咽吟哦咏吠吧唧嗒咐吪隽咀征燐苞茹钙哧吮吰吱嘎吲哚吴栋娇窟孟箫忠晗淞阖闾趼宇呐睛嘘拂捧疵熄竽笛糠吼吽呀吕韦蒙呃呆笨呇贡呉罄呋喃呎呏呔呠呡痴呣呤呦呧瑛眩扒晬淑姬瑜璇鹃呪呫哔嚅嗫呬呯呰呱呲咧噌钝呴呶呷呸呺呻哱咻啸噜吁坎坷逻呿咁咂咆哮咇咈咋蟹煦珅蔼咍咑咒诅咔哒嚓咾哝哩喱咗咠咡咢咣咥咦咨嗟询咩咪咫啮啮咭咮咱咲咳呛嗽咴啕咸咹咺呙喉咿婉恸悯赋矜绿茗蓝哂抢瞒哆嗦啰噻啾滨彗哋哌哎唷哟哏哐哞哢哤哪里哫啼喘哰哲萎蚌哳咩哽哿呗唅唆唈唉唎唏哗尧棣殇璜睿肃唔睇唕吣唞唣喳唪唬唰喏唲唳唵嘛唶唸唹唻唼唾唿啁啃鹦鹉啅埠栈榷祺铺鞅飙啊啍啎啐啓啕啖啗啜哑祈啢衔啤啥啫啱啲啵啺饥啽噶昆沁喁喂喆裙喈咙喋喌喎喑喒喓喔粗喙幛庆滋鹊喟喣喤喥喦喧骚喨喩梆吃葡萄喭驼挑吓碰枞瓣纯疱藻趟铬喵営喹喺喼喿嗀嗃嗄嗅嗈嗉嗊嗍嗐嗑嗔诟嗕嗖嗙嗛嗜痂癖嗝嗡嗤嗥嗨唢嗬嗯嗰嗲嗵叽嗷嗹嗾嗿嘀嘁嘂嘅惋嘈峪禾荫啀嘌嘏嘐嘒啯啧嘚唛嘞嘟囔嘣嘥嘦嘧嘬嘭这谑严敞馋松哓嘶嗥呒虾嘹嘻啴嘿噀噂噅噇噉噎噏噔噗噘噙噚咝噞噢噤蝉皿噩噫噭嗳噱哙噳嚏涌洒欲巫霏噷噼嚃嚄嚆抖哜尝嚔苏嚚嚜嚞嚟呖嚬嚭嚮嚯亸喾饬按竣苛嚵嘤啭冁呓膪谦囍囒囓囗囘萧酚飘溅谛囝溯眸纥銮鹘囟殉囡団囤囥囧囨囱囫囵囬囮囯囲図囶囷囸囹圄圉拟囻囿圀圂圃圊粹蠹赦圌垦圏滚鲱凿枘圕圛圜圞坯埂壤骸炕祠窑豚绅魠鲮鳖圧握圩圪垯圬圮圯炸岬幔毯祇窨菩溉圳圴圻圾坂坆沾坋坌舛壈昆垫墩椅坒坓坩埚坭坰坱坳坴坵坻坼杨挣涎帘垃垈垌垍垓垔垕垗垚垛垝垣垞垟垤垧垮垵垺垾垿埀畔埄埆埇埈埌殃隍埏埒埕埗埜垭埤埦埧埭埯埰埲埳埴埵埶绋埸培怖桩础辅埼埽堀诃侄庑堃堄摧磐贞韧砌堈堉垩堋堌堍堎垴堙堞堠礁堧堨舆堭堮蜓摘堲堳堽堿塁塄塈煤茔棵塍垲埘塓绸塕鸦沽虱塙冢塝缪塡坞埙塥塩塬塱场螨塼塽塾塿墀墁墈墉墐夯増毁墝墠墦渍钵墫墬堕墰墺墙橱壅壆壊壌壎壒榨蒜壔壕壖圹垆壜壝垅壡壬壭壱売壴壹壻壸寝壿夂夅夆変夊夌漱邑夓腕泄甥御骼夗夘夙衮瑙妊娠醣枭珊莺鹭戗幻魇夤蹀秘擂鸫姚宛闺屿庾挞拇賛蛤裨菠氅漓捞湄蚊霆鲨箐篆篷荆肆舅荔鲆巷惭骰辟邱镕镰阪漂烩鲵鲽鳄鸨胪鹏妒峨谭枰晏玑癸祝秤竺牡籁恢罡蝼蝎赐绒御梭夬夭砣榆怙枕夶夹馅奄崛葩谲奈贺祀赠奌奂奓奕䜣詝奘奜奠奡奣陶奨奁魁奫奬奰娲孩贬隶酥宄狡猾她姹嫣妁毡荼皋膻蝇嫔妄妍嫉媚娆妗趣妚妞妤碍妬娅妯娌妲妳妵妺姁姅姉姗姒姘姙姜姝姞姣姤姧姫姮娥姱姸姺姽婀娀诱慑胁娉婷娑娓娟娣娭娯娵娶娸娼婊婐婕婞婤婥溪孺婧婪婬婹婺婼婽媁媄媊媕媞媟媠媢媬媮妫媲媵媸媺媻媪眯媿嫄嫈袅嫏嫕妪嫘嫚嫜嫠嫡嫦嫩嫪毐嫫嫬嫰妩嫺娴嫽嫿妫嬃嬅嬉耍婵痴艳嬔嬖嬗嫱袅嫒嬢嬷嬦嬬嬭幼嬲嬴婶嬹嬾嬿孀娘孅娈孏曰癫屏孑孓雀孖斟篓谜摺孛矻鸠崮轲祜鸾孥邈毓棠膑孬孭孰孱孳孵泛罔衔孻孪宀宁冗拙株薇掣抚琪瓿榴谧弥宊濂祁瑕宍宏碁宓邸谳実潢町宥宧宨宬徵崎骏掖阙臊煮禽蚕宸豫寀寁寥寃檐庶寎暄碜寔寖寘寙寛寠苫寤肘洱滥蒗陕核寪弘绰螽宝擅疙瘩晷対檐専尃尅赎绌缭畴衅尌峙醌襟痲碧屁昊槌淘恵瀑牝畑莓缸羚觑蔻脏躁尔尓锐尗尙尜尟尢��尨尪尬尭尰擒尲尶尴尸尹潽蠖蛾尻扣梢蚴鳍脬蹲屇屌蚵屐屃挪屖屘屙屛屝屡屣峦嶂岩舄屧屦屩屪屃屮戍驻钾崖嵛巅旮旯楂榄榉芋茱萸靛麓屴屹屺屼岀岊岌岍阜岑彭巩岒岝岢岚岣岧岨岫岱岵岷峁峇峋峒峓峞峠嵋峨峰峱岘峹峿崀崁崆祯崋崌崃岖昆崒崔嵬巍萤颢崚崞崟崠峥巆崤崦崧殂岽崱崳崴崶崿嵂嵇嵊泗嵌嵎嵒嵓岁嵙嵞嵡嵩嵫嵯嵴嵼嵾嵝崭崭晴嶋嶌嶒嶓嵚崂嶙嶝嶞峤嶡嶢峄嶨嶭嶮嶰嶲岙嵘巂巃巇巉岿巌巓巘巛滇芎巟巠弋回巣巤炊擘蜥蟒蛊觋巰蜀彦淖杏茂甫楞巻巽帼巿帛斐鲫蕊帑帔帗帚琉汶帟帡帣帨裙帯帰帷帹暆帏幄帮幋幌幏帻幙帮幞幠幡幢幦幨幩幪帱幭幯幰遥蹉跎馀庚鉴幵幷稚邃庀庁広庄庈庉笠庋跋庖牺庠庤庥鲸庬庱庳庴庵馨衢庹庿廃厩廆廋廌廎廏廐廑廒荫廖廛厮搏锣廞弛袤廥廧廨廪廱绵踵髓廸迫瓯邺廻廼廾廿躔弁皱弇弌弍弎弐弑吊诡憾荐弝弢弣弤弨弭弮弰弪霖繇焘斌旭溥骞弶弸弼弾彀彄别累纠强彔彖彘彟彟陌彤贻彧绘虹彪炳雕蔚鸥彰瘅彲彳彴仿彷徉徨彸彽踩敛旆徂徇徊渭畲铉裼従筌徘徙徜徕膳苏萌渐徬徭醺徯徳徴潘徻徼忀瘁胖燎怦悸颤扉犀澎湃砰恍惚绞隘忉惮挨饿忐忑忒忖応忝忞耿忡忪忭忮忱忸怩忻悠懑怏遏怔怗怚怛怞怼黍讶怫怭懦怱怲恍怵惕怸怹恁恂恇恉恌恏恒恓恔恘恚恛恝恞恟恠恣恧眄恪恫恬澹恰恿悀悁悃悄悆悊悐悒晦悚悛悜悝悤您悩悪悮悰悱凄恻德悴怅惘闷悻悾惄愫钟蒐惆惇惌惎惏惓惔惙惛耄惝疟浊恿惦德恽惴蠢惸拈愀愃愆愈愊愍愐愑愒愓愔愕恪氓蠢騃昵惬赧悫愬愮愯恺愼慁恿慅慆慇霭慉慊愠慝慥怄怂慬慱悭慴慵慷戚焚憀灼郁憃惫憋憍眺捏轼愦憔憖憙憧憬憨憪憭怃憯憷憸憹憺懃懅懆邀懊懋怿懔懐懞懠懤懥恹懫懮懰懱毖懵遁梁雍忏懽戁戄戆戉戋戕戛戝戛戠戡戢戣戤戥戦戬戭戯轰戱披菊牖戸戹戺戻卯戽锹扂楔扃扆扈扊杖牵绢铐镯赉扐搂搅烊盹瞌跟趸镲靶鼾払扗玫腮扛扞扠扡扢盔押扤扦扱罾揄绥鞍郤窾扻扼扽抃抆抈抉抌抏瞎抔缳缢擞抜拗択抨摔歉蹿牾抶抻搐泵菸拃拄拊髀抛拌脯拎拏拑擢秧沓曳挛迂拚拝拠拡拫拭拮踢拴拶拷攒拽掇芥橐簪摹疔挈瓢骥捺蹻挌挍挎挐拣挓挖掘浚挙揍聩挲挶挟挿捂捃捄捅捆捉捋胳膊揎捌捍捎躯蛛捗捘捙捜捥捩扪捭据捱捻捼捽掀掂抡臀膘掊掎掏掐笙掔掗掞棉芍掤搪阐掫掮掯揉掱掲掽掾揃揅揆搓揌诨揕揗揘揜揝揞揠揥揩揪揫橥遒麈揰揲揵揶揸背揺搆搉搊搋搌搎搔搕撼橹捣搘搠搡搢搣搤搥搦搧搨搬楦裢讪赸掏搰搲搳搴揾搷搽搾搿摀摁摂摃摎掴摒摓跤摙摛掼摞摠摦喉羯摭摮挚摰摲抠摴抟摷掺摽撂撃撅稻撊撋挦锏泼撕撙撚㧑挢撢掸撦撅撩撬撱朔揿蚍蜉挝捡擀掳闯擉缶觚擐擕擖擗擡擣擤澡腚擧擨擩擫擭摈拧撷擸撸擽擿攃摅撵攉攥攐攓撄搀撺每攩攫辔澄攮攰攲攴轶攷砭讦攽碘敁敃敇敉叙敎筏敔敕敖闰诲敜煌敧敪敳敹敺敻敿斁衽斄牒绉诌斉斎斓鹑谰驳鳢斒筲斛斝斞斠斡斢斨斫斮晾沂潟颖绛邵斲斸釳於琅斾斿旀旗旃旄涡旌旎旐旒旓旖旛旝旟旡旣浴旰獭魃旴时旻旼旽昀昃昄昇昉晰躲澈熹皎皓矾昑昕昜昝昞昡昤晖笋昦昨是昱昳昴昶昺昻晁蹇隧蔬髦晄晅晒晛晜晞晟晡晢晤晥曦晩萘莹顗晿暁暋暌暍暐暔暕煅旸暝暠暡曚暦暨暪朦胧昵暲殄冯暵暸暹暻暾曀晔昙曈曌曏曐暧曘曙曛叠昽曩骆曱甴肱曷牍禺锟曽沧耽朁朅朆杪栓夸竟粘绦朊膺朏朐朓朕朘朙瞄觐溘饔飧朠朢朣栅椆淀虱朩朮朰朱炆璋钰炽鹮朳槿朵朾朿杅杇杌陧欣钊湛漼楷瀍煜玟缨翱肇舜贽适逵杓杕杗杙荀蘅杝杞脩珓筊杰榔狍閦颦缅莞杲杳眇杴杶杸杻杼枋枌枒枓衾葄翘纾逋枙狸桠枟槁枲枳枴枵枷枸橼枹枻柁柂柃柅柈柊柎某柑橘柒柘柙柚柜柞栎柟柢柣柤柩柬柮柰柲橙柶柷柸柺査柿栃栄栒栔栘栝栟柏栩栫栭栱栲栳栴檀栵栻桀骜桁镁桄桉桋桎梏椹葚桓桔桕桜桟桫椤桭杯桯桲桴桷桹湘溟梃梊梍梐潼栀枧梜梠梡梣梧梩梱梲梳梴梵梹棁棃樱棐棑棕榈簑绷蓑枨棘棜棨棩棪棫棬棯棰棱棳棸棹椁棼碗椄苕椈椊椋椌椐椑椓椗検椤椪椰椳椴椵椷椸椽椿楀匾楅篪楋楍楎楗楘楙楛楝楟楠楢楥桢楩楪楫楬楮楯楰梅楸楹楻楽榀榃榊榎槺榕榖榘榛狉莽搒笞榠榡榤榥榦榧杩榭榰榱梿霰榼榾桤槊闩槎槑槔槖様槜槢槥椠槪槭椮槱槲槻槼槾樆樊樏樑樕樗樘樛樟樠樧樨権樲樴樵猢狲桦樻罍樾樿橁橄橆桡笥龠橕橚橛辆椭橤橧竖膈跨橾橿檩檃檇柽檍檎檑檖檗桧槚檠樯檨檫檬梼槟檴檵柠棹櫆櫌栉櫜椟櫡槠栌枥榇栊櫹棂茄櫽欀欂欃欐欑栾欙棂溴欨欬欱欵欶欷歔欸欹欻欼欿歁歃歆艎歈歊莳蝶歓歕歘歙歛歜欤歠蹦诠镶蹒跚升陟歩歮歯歰歳歴璞歺瞑歾殁夭殈殍殑殗殜殙殛殒殢殣殥殪殚僵殰殳荃殷殸殹蛟殻肴谤殴毈毉喂毎���蕈毗毘毚茛邓毧毬毳毷毹毽毾毵牦氄氆靴氉氊氇氍氐聊氕氖気氘氙氚氛氜氝氡汹焊痉氤氲氥氦铝锌氪烃氩铵痤汪浒漉痘盂碾菖蒲蕹蛭螅氵冰氹氺氽烫氾氿渚汆汊汋汍汎汏汐汔汕褟汙汚汜蓠沼秽蔑汧汨汩汭汲汳汴堤汾沄沅沆瀣沇沈葆浸沦湎溺痼疴沌沍沏沐沔沕沘浜畹砾沚沢沬沭沮沰沱灢沴沷籽沺烹濡洄泂肛泅泆涌肓泐泑泒泓泔泖泙泚泜泝泠漩馍涛粼泞藓鳅泩泫泭泯铢泱泲洇洊泾琵琶荽蓟箔洌洎洏洑潄濯洙洚洟洢洣洧洨洩痢滔洫洮洳洴洵洸洹洺洼洿淌蜚浄浉浙赣渫浠浡浤浥淼瀚浬浭翩萍浯浰蜃淀苔蛞蝓蜇螵蛸煲鲤浃浼浽溦涂涊涐涑涒涔滂莅涘涙涪涫涬涮涴涶涷涿淄淅淆淊凄黯淓淙涟淜淝淟淠淢淤渌淦淩猥藿亵淬淮淯淰淳诣涞纺淸淹炖癯绮渇済渉渋渓渕涣渟渢滓渤澥渧渨渮渰渲渶渼湅湉湋湍湑湓湔黔湜湝浈湟湢湣湩湫湮麟湱湲湴涅満沩溍溎溏溛舐漭溠溤溧驯溮溱溲溳溵溷溻溼溽溾滁滃滉滊荥滏稽滕滘汇滝滫滮羼耷卤滹浐煎漈漊漎绎漕漖漘漙沤漜漪漾漥漦漯漰溆漶漷濞潀颍潎潏潕潗潚潝潞潠潦祉疡潲潵滗潸潺潾涠澁澂澃澉澌澍澐澒澔澙渑澣澦澧澨澫澬浍澰澴澶澼熏郁濆濇濈濉濊貊濔疣濜濠濩觞浚濮盥潍濲泺瀁滢渎渖瀌浏瀒瀔濒泸瀛潇潆瀡潴泷濑瀬弥潋瀳瀵瀹瀺瀼沣滠灉灋灒漓灖灏灞灠滦灥灨滟灪蜴灮烬獴灴灸灺炁炅鱿炗炘炙炤炫疽烙钎炯炰炱炲炴炷毁炻烀烋瘴鲳烓烔焙烜烝烳饪烺
焃焄耆焌焐焓焗焜焞焠焢焮焯焱焼煁煃煆煇煊熠煍熬煐炜煕暖熏硷霾煚煝煟煠茕矸煨琐炀萁煳煺煻熀熅熇熉罴荧穹炝熘熛熜稔谙烁熤熨熯熰眶蚂颎熳熸熿燀烨燂燄盏燊燋燏燔隼燖焖燠燡灿燨燮燹燻燽燿爇爊爓爚爝爟爨蟾爯爰为爻丬爿牀牁牂牄牋窗牏牓窗釉牚腩蒡虻牠虽蛎牣牤牮牯牲牳牴牷牸牼绊牿靬犂犄犆犇犉犍犎犒荦犗犛犟犠犨犩犪犮犰狳犴犵犺狁甩狃狆狎狒獾狘狙黠狨狩狫狴狷狺狻豕狈蜘猁猇猈猊猋猓猖獗猗猘狰狞犸猞猟獕猭猱猲猳猷猸猹猺玃獀獃獉獍獏獐獒毙獙獚獜獝獞獠獢獣獧鼇蹊狯猃獬豸狝獯鬻獳犷猕猡玁菟玅玆玈珉糁禛郅玍玎玓瓅玔玕玖玗玘玞玠玡玢玤玥玦珏瑰玭玳瑁玶玷玹玼珂珇珈瑚珌馐馔珔珖珙珛珞珡珣珥珧珩珪佩珶珷珺珽琀琁陨玡琇琖琚琠琤琦琨琫琬琭琮琯琰琱琲琅琴珐珲瑀瑂瑄瑉玮瑑瑔瑗瑢瑭瑱瑲瑳瑽瑾瑿璀璨璁璅璆璈琏璊璐璘璚璝璟璠璡璥瑷璩璪璫璯璲玙璸璺璿瓀璎瓖瓘瓒瓛脐瓞瓠瓤瓧瓩瓮瓰瓱瓴瓸瓻瓼甀甁甃甄甇甋甍甎甏甑甒甓甔瓮甖甗饴蔗甙诧钜粱盎锈团甡褥産甪甬甭甮宁铠甹甽甾甿畀畁畇畈畊畋畎畓畚畛畟鄂畤畦畧荻畯畳畵畷畸畽畾疃叠疋疍疎箪疐疒疕疘疝疢疥疧疳疶疿痁痄痊痌痍痏痐痒痔痗瘢痚痠痡痣痦痩痭痯痱痳痵痻痿瘀痖瘃瘈瘉瘊瘌瘏瘐痪瘕瘖瘙瘚瘛疭瘜瘝瘗瘠瘥瘨瘭瘆瘯瘰疬瘳疠瘵瘸瘺瘘瘼癃痨痫癈癎癐癔癙癜癠疖症癞蟆癪瘿痈発踔绀蔫酵皙砬砒翎翳蔹钨镴皑鹎驹暨粤褶皀皁荚皃镈皈皌皋皒朱皕皖皘皜皝皞皤皦皨皪皫皭糙绽皴皲皻皽盅盋碗盍盚盝踞盦盩秋千盬盭眦睁瞤盯盱眙裰盵盻睐眂眅眈眊県眑眕眚眛眞眢眣眭眳眴眵眹瞓眽郛睃睅睆睊睍睎困睒睖睙睟睠睢睥睪睾睯睽睾眯瞈瞋瞍逛瞏瞕瞖眍䁖瞟瞠瞢瞫瞭瞳瞵瞷瞹瞽阇瞿眬矉矍铄矔矗矙瞩矞矟矠矣矧矬矫矰矱硪碇磙罅舫阡、矼矽礓砃砅砆砉砍砑砕砝砟砠砢砦砧砩砫砮砳艏砵砹砼硇硌硍硎硏硐硒硜硖砗磲茚钡硭硻硾碃碉碏碣碓碔碞碡碪碫碬砀碯碲砜碻礴磈磉磎硙磔磕磖磛磟磠磡磤磥蹭磪磬磴磵磹磻硗礀硚礅礌礐礚礜礞礤礧礮砻礲礵礽礿祂祄祅祆禳祊祍祏祓祔祕祗祘祛祧祫祲祻祼饵脔锢禂禇禋祦禔祎隋禖禘禚禜禝禠祃禢禤禥禨禫祢禴禸秆秈秊闱飒秋秏秕笈蘵赁秠秣秪秫秬秭秷秸稊稌稍稑稗稙稛稞稬秸稲稹稼颡稿穂穄穇穈穉穋稣贮穏穜穟秾穑穣穤穧穨穭穮穵穸窿阒窀窂窅窆窈窕窊窋窌窒窗窔窞窣窬黩蹙窑窳窴窵窭窸窗竁竃竈竑竜并竦竖篦篾笆鲛竾笉笊笎笏笐靥笓笤箓笪笫笭笮笰笱笲笳笵笸笻筀筅筇筈筎筑筘筠筤筥筦笕筒筭箸筰筱筳筴宴筸箂个箊箎箑箒箘箙箛箜篌箝箠箬镞箯箴箾篁筼筜篘篙篚篛篜篝篟篠篡篢篥篧篨篭篰篲筚篴篶篹篼箦簁簃簆簉簋簌簏簜簟簠簥簦簨簬簰簸簻籊藤籒籓籔签籚篯箨籣籥籧笾簖籫籯芾麴籵籸籹籼粁秕粋粑粔粝粛粞粢粧粨粲粳稗粻粽辟粿糅糆糈糌糍糒糔萼糗蛆蹋糢糨糬粽糯糱籴粜糸糺紃蹼鲣霉纡纨绔纫闽襻紑纰纮锭鸢鹞纴紞紟扎紩紬绂绁纻紽紾绐絁絃絅経絍绗絏缡褵絓絖絘絜绚絣螯絪絫聒絰絵绝絺絻絿綀绡綅绠绨绣綌綍綎捆綖綘継続缎绻綦綪线綮綯绾罟蝽綷縩绺绫緁绲緅緆缁绯緌緎総緑绱緖缃缄缂绵缗緤褓缌纂緪緰缑缈缏缇縁縃縄萦缙缒縏缣縕缞縚缜缟缛縠縡縢縦绦縯縰骋缧縳纤缦絷缥縻衙縿繄缫繈繊繋繐缯繖繘繙繠缋繣繨缰缲繸繻缱纁纆纇缬缵纩纑纕缵纙纚纛缾罃罆坛罋罂罎罏罖罘罛罝罠罣罥罦罨罫罭锾罳罶罹罻罽罿羂羃羇芈蕉51鸵羑羖羌羜羝羢羣羟羧羭羮羰羱羵羶羸藜鲐翀翃翅翊翌翏翕翛翟翡翣翥翦跹翪翫翚翮翯翱翽翾翿板饕鸹锨耋耇耎耏专耒耜耔耞耡耤耨耩耪耧耰鬓耵聍聃聆聎聝聡聦聱聴聂聼阈聿肄肏肐肕腋肙肜肟肧胛肫肬肭肰肴肵肸肼胊胍胏胑胔胗胙胝胠铨胤胦胩胬胭胯胰胲胴胹胻胼胾脇脘脝脞脡脣脤脥脧脰脲脳腆腊腌臜腍腒腓胨腜腠脶腥腧腬腯踝蹬镣腴腶蠕诽膂腽嗉膇膋膔腘膗膙膟黐膣膦膫膰膴膵膷脍臃臄臇臈臌臐臑臓膘臖臙臛臝臞臧蓐诩臽臾臿舀舁鳑鲏舋舎舔舗馆舝舠舡舢舨舭舲舳舴舸舺艁艄艅艉艋艑艕艖艗艘艚艜艟艣舣艨艩舻艬艭荏艴艳艸艹艻艿芃芄芊萰陂藭芏芔芘芚蕙芟芣芤茉芧芨芩芪芮芰鲢芴芷芸荛豢芼芿苄苒苘苙苜蓿苠苡苣荬苤苎苪镑苶苹苺苻苾茀茁范蠡萣茆茇茈茌茍茖茞茠茢茥茦菰茭茯茳藨茷藘茼荁荄荅荇荈菅蜢鸮荍荑荘豆荵荸荠莆莒莔莕莘莙莚莛莜莝莦莨菪莩莪莭莰莿菀菆菉菎菏菐菑菓菔芲菘菝菡菢菣菥蓂菧菫毂蓥菶菷菹醢菺菻菼菾萅萆苌萋萏萐萑萜萩萱萴莴扁萻葇葍葎葑荭葖葙葠葥苇葧葭药葳葴葶葸葹葽蒄蒎莼茏薹莅蒟蒻蒢蒦蒨蒭藁蒯蒱鉾蒴蒹蒺蒽荪蓁蓆蓇蓊蓌蓍蓏蓓蓖蓧蓪蓫荜跣藕苁蓰蓱莼蓷蓺蓼蔀蔂蔃蔆蔇蔉蔊蔋蔌蔎蔕蔘蔙蒌蔟锷蒋雯茑蔯蔳麻蔵蔸蔾荨蒇蕋蕍荞蕐蕑芸莸蕖蕗蕝蕞蕠蕡蒉蕣蕤蕨蕳蓣蕸蕺蕻薀薁薃薅薆荟薉芗薏薐蔷薖薘剃谔钗薜薠薢薤薧薨薫薬薳薶薷薸薽薾薿藄藇藋荩藐藙藚藟藦藳藴苈藷藾蘀蘁蕲苹蘗蘘蘝蘤蘧蘩蘸蘼虀虆虍蟠虒虓虖虡虣虥虩虬虰蛵蛇虷鳟虺虼蚆蚈蚋蚓蚔蚖蚘蚜蚡蚣蚧蚨蚩蚪蚯蚰蜒蚱蚳蚶蚹蚺蚻蚿蛀蛁蛄蛅蝮蛌蛍蛐蟮蛑蛓蛔蛘蛚蛜蛡蛣蜊蛩蛱蜕螫蜅蚬蜈蝣蜋蜍蜎蜑蠊蜛饯蜞蜣蜨蜩蜮蜱蜷蜺蜾蜿蝀蝃蝋蝌蝍蝎蝏蝗蝘蝙蝝鲼蝡蝤蝥猿蝰虻蝲蝴蝻螃蠏蛳螉螋螒螓螗螘螙螚蟥螟螣螥螬螭䗖螾螀蟀蟅蝈蟊蟋蟑蟓蟛蟜蟟蟢虮蟨蟪蟭蛲蟳蛏蟷蟺蟿蠁蠂蠃虿蠋蛴蠓蚝蠗蠙蠚蠛蠜蠧蟏蠩蜂蠮蠰蠲蠵蠸蠼蠽衁衄衄衇衈衉衋衎衒同衖胡衞裳钩衭衲衵衹衺衿袈裟袗袚袟袢袪袮袲袴袷袺袼褙袽裀裉袅裋夹裍裎裒裛裯裱裲裴裾褀褂褉褊裈褎褐褒褓褔褕袆褚褡褢褦褧褪褫袅褯褰褱裆褛褽褾襁褒襆裥襉襋襌襏襚襛襜裣襞襡襢褴襦襫襬襭襮襕襶襼襽襾覂覃覅霸覉覊覌覗觇覚覜觍觎覧覩觊觏覰観觌觔觕觖觜觽觝觡酲觩觫觭觱觳觯觷觼觾觿言赅讣訇訏訑訒诂讬訧訬訳訹证訾詀詅诋毁詈詊讵詑诒诐詗诎察詨诜詶詸詹詻诙诖誂誃诔锄诓誋诳诶悖誙诮诰誧説読誯谇訚谄谆諆諌诤诹诼諕谂谀諝谝諟喧谥諴諵谌谖誊謆謇歌謍謏謑谡谥謡謦謪谪讴謷謼谩哗譅譆譈譊讹譒撰谮鑫譞噪譩谵譬譱譲谴譸譹谫讅讆詟䜩雠讐谗谶讙谠讟谽豁豉豇岂豊豋豌豏豔豞豖豗豜豝豣豦豨豭豱豳豵豶豷豺豻貅貆狸猊貔貘䝙貜貤餍贳餸贶贲赂賏赊赇赒賝赓赕賨赍斗賮賵賸赚赙赜赟贉赆赑贕赝赬赭赱赳迄趁趂趄趐趑趒趔趡趦趫趮趯趱趴趵趷趹趺趿跁跂跅跆踬跄跐跕跖跗跙跛跦跧跩跫跬跮跱跲跴跺跼跽踅踆踈踉踊踒踖踘踜踟躇蹰踠踡踣踤踥踦踧跷踫踮逾踱踊踶踹踺踼踽躞蹁蹂躏蹎蹐蹓蹔跸蹚蹜蹝迹蹠蹡蹢跶蹧蹩蹪蹯鞠蹽躃躄躅踌跻躐踯跞躘躙躗躝躠蹑躜躧躩躭躰躬躶軃軆辊軏轫軘軜軝腭転軥軨軭軱轱辘軷轵轺軽軿輀輂辇辂辁輈挽輗辄辎辋輠輤輬輭輮辏輴輵輶輹輼辗辒轇轏轑轒辚轕轖轗轘轙轝轞轹轳罪辣辞辵辶辺込辿迅迋迍麿迓迣迤逦迥迨迮迸迺迻迿逄逅逌逍逑逓迳逖逡逭逯逴逶逹遄遅侦遘遛遝遢遨遫遯遰遴绕遹遻邂邅邉邋邎邕邗邘邛邠邢邧邨邯郸邰邲邳邴邶邷邽邾邿郃郄郇郈郔郕郗郙郚郜郝郞郏郠郢郪郫郯郰郲郳郴郷郹郾郿鄀鄄郓鄇鄈鄋鄍鄎鄏鄐鄑邹邬鄕郧鄗鄘鄚鄜鄞鄠鄢鄣鄤鄦鄩鄫鄬鄮鄯鄱郐鄷鄹邝鄻鄾鄿酃酅酆酇郦酊酋酎酏酐酣酔酕醄酖酗酞酡酢酤酩酴酹酺醁醅醆醊醍醐醑醓醖醝酝醡醤醨醪醭醯醰酦醲醴醵醸醹醼醽醾釂酾酽釆釈鲈镏阊钆钇钌钯钋鼢鼹钐钏釪釬釭釱钍釸钕钫鈃钭鈆鈇钚鈊鈌钤钣鈒鈤钬钪鈬铌铈钶铛钹铍钸钿鉄鉆铊铇鉌铋鉏铂钷铆钵鉥钲鉨钼钽鉱鉲鉶铰铒鉼铪銍銎铣銕镂铫铦铑铷銤铱铟銧铥铕铯銭銰焊銶锑锉汞鋂锒鋆鋈鋊铤鋍铗鋐鋑鋕鋘鋙锊锓锔锇铓鋭铖锆锂铽鋳鋹鋺鉴镚钎錀锞锖锫锩錍铔锕錔锱铮锛錞锬锜錤錩錬録铼錼锝钔锴鍉镀鍏鍐铡鍚锻锽锸锲锘鍫鍭鍱鍴锶鍹锗针锺锿镅鎉鎋鎌鎍鎏鎒鎓鎗镉鎚鎞镃鎤铩锼鎭鎯镒镍鎴镓��鎹镎镟鏊镆镠镝鏖铿锵鏚镗镘镛鏠鏦錾镤鏸镪鏻鏽鏾铙鐄鐇鐏铹镦镡鐗馗镫镢镨鐡锎镄鐩镌鐬鐱镭鐶鐻鐽镱鑀鑅镔鑐鑕鑚鑛鑢鑤镥鑪镧鑯鑱鑴鑵镊镢钃镻闫闬闶闳閒闵閗閟阂関合閤哄阆閲阉閺阎阏阍阌暗闉阕阗闑闒闿闘闚阚闟闠闤闼阞阢阤阨阬阯阹阼阽陁陑陔陛陜陡陥陬骘陴険陼陾阴隃隈隒隗隞隠隣隤隩隮隰颧隳隷隹雂雈雉雊雎雑雒雗雘雚雝雟雩雰雱驿霂霅霈霊沾霒霓霙霝霢霣霤霨霩霪霫霮靁叇叆靑靓靣腼靪靮靰靳靷靸靺靼靿鞀鞃鞄鞍鞗鞙鞚鞝鞞鞡鞣鞨鞫鞬鞮鞶鞹鞾鞑韅鞯驮韍韎韔韖韘韝韫韡韣韭韭韱韹韺頀刮頄顸顼頍颀颃颁頖頞頠頫頬颅頯頲颕頼悴顋顑颙颛颜顕顚顜颟顣颥颞飐飑台飓颸飏飖颽颾颿飀飂飚飌翻飡飣饲飥饨饫飮飧飶餀餂饸饹餇餈饽哺馂餖餗餚馄馃餟餠餤餧餩餪餫糊餮糇餲饧馎糕饩馈馊馌馒饇馑馓膳饎饐饘饟馕馘馥馝馡馣骝骡馵馹駃駄駅駆駉駋驽駓驵駗骀驸駜骂骈駪駬骃駴骎駹駽駾騂騄骓騆騉騋骒骐麟騑騒験騕骛騠騢騣騤騧骧騵驺骟騺蓦骖骠骢驆驈骅驌骁驎骣驒驔驖驙驦驩驫骺鲠骫骭肮骱骴骶骷髅骾髁髂髄髆膀髇髑髌髋髙髝髞髟髡髣髧髪髫髭髯髲髳髹髺髽髾鬁鬃鬅鬈鬋鬎鬏鬐鬑鬒鬖鬗鬘鬙鬠鬣斗鬫鬬阄鬯鬰鬲鬵鬷魆魈魊魋魍魉魑魖鳔魛魟魣魦魨魬鲂魵魸鮀鲅鮆鲧鲇鲍鲋鮓鲒鲕鮟鱇鮠鮦鮨鲔鲑鮶鮸鮿鲧鯄鯆鲩鯈鲻鯕鲭鲞鯙鯠鲲鯥鲰鲶鳀鯸鳊鲗䲠鹣鳇鰋鳄鳆鰕鰛鰜鲥鰤鳏鰦鳎鳐鳁鳓鰶鲦鲡鰼鰽鱀鱄鳙鱆鳕鱎鱐鳝鳝鳜鲟鲎鱠鳣鱨鲚鱮鱲鱵鱻鲅鳦凫鳯鳲鳷鳻鴂鴃鴄鸩鴈鴎鸰鴔鴗鸳鸯鸲鹆鸱鴠鴢鸪鴥鸸鹋鴳鸻鴷鴽鵀鵁鸺鹁鵖鵙鹈鹕鹅鵟鵩鹌鵫鵵鵷鵻鹍鶂鶊鶏鶒鹙鶗鶡鶤鶦鶬鶱鹟鶵鶸鶹鹡鶿鹚鷁鷃鷄鷇䴘䴘鷊鷏鹧鷕鹥鸷鷞鷟鸶鹪鹩鷩鷫鷭鹇鹇鸴鷾䴙鸂鸇䴙鸏鸑鸒鸓鸬鹳鸜鹂鹸咸鹾麀麂
麃麄麇麋麌麐麑麒麚麛麝麤麸面麫麮麯麰麺麾黁黈黉黢黒黓黕黙黝黟黥黦黧黮黰黱黪黶黹黻黼黾鼋鼂鼃鼅鼈鼍鼏鼐鼒冬鼖鼙鼚鼛鼡鼩鼱鼪鼫鼯鼷鼽齁齆齇齈齉齌赍齑龀齕齗龅齚龇齞龃龉龆齢出齧齩齮齯齰齱齵齾厐龑龒龚龖龘龝龡龢龤'
|
18 |
+
|
19 |
+
traditional_characters = '制咖片型超聲盤鑒定仔點他命書歌粉巾字帳恤手指記憶棒形轉彎溝光○〇㐄㐅㐆㐌㐖毒㐜㐡㐤㐰㐺㑇㑳㒳㒸㔾㗂㗎㝵㞎㞙㞞㠯㢲㢴㤅㥁㥯㨗㫺㬎㮎㮚㮸㲋㲱㲾㳮㵎㵪㶸㷖㷭㹢㹴犬㺢狓㺵㼝㽮㿝䍃䔢䖟䖸䗈䗥䗪䝓䠶䥯䦉䯝䰾魚䲔䳗䳘䵹鼄䶑一對應映射丁不識下兒子做二休世丘之貉並中台原則串為甚謂乾淨了百事無成八變五十些人得道雞升天代如併來去個國政策勁幽靈在歐洲遊蕩接樣蘿蔔坑側化傳價元論醇共再准刀兩斷切分耕耘收穫錢貨物向看舊就緒險刻千金動勞永逸匙零夜半卡通回復返影蹤反常態口咬氣句話同吐快吹周味呼諾嗚品紅鍋哄而散起唱和問三知生熟團漆黑火糟堆場空塊麵塌糊塗塵染壁廂夔已足多情露水大早到晚夫妻當關萬莫開失古恨套所料既往孔見提師要家主審寸陰難買鬥牛小撮部陣局展身層巴掌帆風順席地帶過年計於春頭載四季期被蛇怕井繩度願式份彈頃深前律徑心意念差愁孤行俱全房廳交遮打技長把抓死拿眼淚鼻涕鑰鎖折段抿拍即合掃排掬揮撥擁上入擊洞擲攬改故轍敗文值名斑方面旁族日秋餐隔雅里終父旦時晌會霎間晃暴寒曝更月望垠際朝夕本正經利杯羹東西板枝獨秀根筋桿進條龍服務概模次函數又性程總付步腳印趨登毛拔呵氧氮碳決雌雄波未平派謊言流清楚白準溜煙潭有獲聞是處降琴鶴甲病發可拾沙目然瞭直以相眨穿睹瞥瞬矢的解石鳥神教秉虔誠秘種窩蜂窮竅笑置筆苟勾銷抹殺煞等獎箍節吃箭仇雙鵰詩籌籮筐系列紙級士官統絲毫掛維網盡線微吭響股腦胎脈承腔臂力致效資源址器舉功投般說講規貿易葉障著慎滿皆輸號木電池衣傾鐘高低視仁覺醒覽遺角銀幣觸潰九鼎蔽抄出駟馬追重語破貧洗貫走路安蹴至幾蹶振躍役膽汗較輩輪辭贊退六連遍遞邊針血錘音錯門思閃真倒項栽霧類保護川先驚乍體鬨鱗爪鳴滴泡鄰域黨專鼓作齊炒丑烯亥克內酯冬加奴卯肝炎基尺梁街褲鎬客寵庭巳汝昌烷玲磊糖肇酉醛啷青縣韙良香骨鯛丂七集河市弦喜嘴張舌堵區工業姊妹星架構巧彩扭歪拼湊餘熱曜武州爺浮屠美鄉老階樹葷素碎落能魄鰓鰻珠丄丅丆万俟丈尚摸母娘量管群亞虎必我堂令申件裝伏位博俠義界表女墟臺戲臭皮匠勝諸葛亮賽頂倍催請運算包立叉戟離疫苗土史志演圍揭瓦曬夷姑婆帝村寶爛尖杉鹼屜桌山岔島由紀峽壩庫鎮廢從德後拗湯治旬食明昧曹朋友框欄極權冪曲歸依貓民氟硼氯磷鐵江侗自旅法司洋浦梅園溫暖灣焦班幸用田略番疊皇炮捶硝苯酸腺苷稜草鏡穗跳遠索錦綱聚氰胺聯店胚膲愛色堇紫羅蘭芝茶飯菱雲蟲藏藩亂叛蘇親債凳學座恐戀柱測肌腹衩錐係貂企烏跪叩軍車農題迭都甘油屯奏鍵短阿姨陪姐隻顧茅廬槽駕魂鮮鹿頁其菜單乘任供勢午齒漢組織吊調瀉唇坡城報墳外夸將尉建築岸崗公床揚新劍昇杭林栗校樓標款汽社浣海商館劇院鋼華港機械廣媒環球融第醫科證券綜財樂育游漲猶嶺疏癮瞼確兵領導繳肢膛船艾瑟爾蒼蔡虞傚衫覆訪訴課諭議軌述野鉤限敵鞋頜頷顎饒首齦站例修凡劃垂屆屬崽頦廚拜挫擺放旋削棋榻檻禮沉注滑營獄畫确儀聘花葬詔員跌轄週達酒錨閘陷陸雨雪飛威丌于丹久乏予理評產亢卑亦乎舞己悲矩圓詞害誌但住佞佳便俗信票案幅翁倦倫假偏倚斜虧鬼敲停備傷脾胃僅此像儉匱免宜穴焉戴兼容許凍伯仲負彼晝皂軒輊實刊划顛衛戰哥比省非好黃飾別拘束掩奶睬選擇搖擾煩苦枚寫協厭及格受歡迎約只估侵犯割狀告或缺抗拒挽撤救藥喻磨滅端倪少逆逾越避靠適吉譽吝玉含延咎歹聽啻淵善謀均勻堪忍夠太惹妙妥妨孕症孝術室完納推冠積宣疑辯慄碴稱屈撓屑干涉衡待很忙惡忿怎麼怠急恥恭息悅惑惜惟想愉愧怍慌憤啟懂懈懷材才緊招認扣抵拉捨也罷插揣冒搭撞南牆擴核支攻敢雷攀敬裡嗎需景智暇曾罪遇朽枉止況競爭辱求癒渝溶濟左右袒困補爽特寂寞示弱找謝畏強疾徐痛癢冤符眠睦瞅董何厚云措活疲羞者輕玻璃祥兆禁移稂莠穩佛換答簡結果盟絕縷途給談否羈翼耐肖脛毋寧興舒若菲萊痕跡窠臼虛衰臉兔撒鷹棺範該詳諱抬泰讓鬚眉象眾貲賬費灰賴奇慮訓輟辨菽麥辛近送透逞徒速續逮捕遂遑違遜斧鉞艱醉鏽隨觀棄顯飽脂肪使丏丐幫丒且慢末丕替桃宗王尊涼爵各圖屋脊糧署錄壇吾祿職胄襲君廈丗北壑桐疹損逢陵鷸丙寅戌氨腈唑綸辰酮脫氫酶醚丞丟現掉紗帽弄扯砲碗丠両丣坐存激肩臻蒂蓮悖序驅丨丩丫挺杈髻鬟細介俄伊犁京尼布訂普渡央委監察檢查劑圈設警隊斯督剩震境航舶革防托播促質版蠑螈鋒研藝歷殘消頻譜精密製造陲郵候埔堅壓壢凹匯執府究邦俘攝寮彬狼嶽肺腫庸英訊診埋粒胞括控碼韓暑槍樞砥澳哇牟壽甸鑽探篇簽綴縫繼耳肯照婦埃懸璧軸櫃檯辣擱淺邪跑纖阮陽私囊魔丮丰姿采丱燒丳丵丶丷丸參寨朗桂瑞砂衷霞貌鳳僕艦因嫌宰峰幹絡牌持旨祭禱簿編罰賓辦丼丿乀乂乃乄仰慕盛曠留考驗闊乆乇么醜麼乊湖燃乑乒乓乕乖僻忤戾离謬迕乗危肥劫除隙浪婿乙炔腸酰吡咯鹽乚乛乜嘢卿玄宮尾狐龜塔嶷兄弟泉章霄釘耙乞扎哀憐恕討乢乣乤乥乧乨乩童乪乫乭乳暈汁液瑤漿牙癌突竇罩腐膠豬酪蛋糕菌瘤乴乵乶乷乸乹乺乼乾俸冰嘉噦嚎坤媽屍壘旱枯涸俐渴潮澀煸豆燥爹瘦癟癬瞪袋脆薑貝隆餾乿亀亁叫咕攘扔搞男砸竄蓬麻亃亄亅卻亇遲典今臨繁累卵奉婚聰躬巨與遷添裂副宿歲怪噁尕崙愣杆硅硫鈦鈾錳芑雜異鈉砷胂磺琥珀艙棍簧胡茬盜浩盆販郎腿亍洪亐互欠助勉惠操斥諉繫戶譯亓墓碑刑鈴卅渠繽紛斗米旗憲釩燈徽瘟祖拳福穀豐臟腑綁肉醃苓蘊橋鋪霸顏鬧判噴岡底蛙陘礦亖亙亜罕們娜桑那努哈喀弗烈曼松森杜氏盃奧琛敦戊穆聖裔彙薛孫亟亡佚虜羊牢奮釋卷卸契媾感額睫纏誼趾塞擠紐阻還配馳莊亨洛祚亪享津滬畿郊慈菴枇杷膏亭閣鋥麗亳亶亹誅初責翻瘋偶傑叢稠妖拖寰居吸授慧蝸吞壯魅狗矛盾益渣患憂稀描猿夢暫涯畜禍緣沸搜引擎臣橫紜誰混援蒸獸獅稅剖亻亼亽亾什獻剎邡麽仂仃仄仆富怨仈仉畢昔晨殼紹仍仏仒仕宦仗欺恃腰嘆歎炬梓訖施仙后瓊逝仚仝仞仟悔仡佬償填泊拓撲簇羔購頓欽佩髮棻閫馭養億儆尤藉幀賑凌敘帖李柔剛沃眥睚戒訛取饗讀仨仫仮著泳臥躺韶夏裁仳仵唯賢憑釣誕仿似宋彿諷伀碩盼鵝伄儅伈伉儷柯始娃邁戈坦堡帕茨薩廟瑪莉莎藤霍姆伋伍奢胥廷芳豪伎倆侍汛勒希羲雛伐憩整謨閑閒伕伙伴頤伜伝伢叔恆茲恩翰伱伲侶伶俜悧鼬伸懶縮喇叭伹伺伻伽倻輻伾佀佃佇佈喬妮墨佉盧佌貸劣廉昂檔濃矮傘窪緩耗胸谷迷擋率齲宅沫舍療佐貳佑佔優據鏵嘗呢須魯曉佗佘余坪寺瓜銃僧蒙芒陀龕哼嘔坊姦孽弊揖祟繭縛誓賊佝僂瞀佟你奪趕佡佢佣佤佧賈佪佫佯佰佱潔績釀餚佴捲佶佷佸佹佺佻佼佽佾具喚窘壞娛怒慨硬習慣聾膨脹蔓駭貴痺侀侁侂侃侄侅鴻燕侇侈糜靡侉侌妾侏儒倉鼠侐侑侔侖侘侚鏈侜偎傍鈷循柳葫蘆附価侮罵蔑侯岩截蝕侷貼壺嬛宴捷攜桶箋酌俁狹膝狄俅俉俊俏俎俑俓俔諺俚俛黎健呈固墒增守康箱濕祐鏢鑣槓盒靖膜齡俞豹獵噪孚封札筒託衍鴿剪撰稿煉廠禊練繕葺俯瞰撐衝俲俳俴俵俶俷俺俻俾倀倂倅儲卒惶敷猝逃頡蓄崇隱倌倏忽刺蠟燭噍嚼坍扁抽斃蔥楣灌灶糞背藪賣賠閉霉騰倓倔倖倘倜儻倝借箸挹澆閱倡狂倢倣値倥傯倨��倩匡嗣沖柝珍倬倭寇猩倮倶倷倹勤讚偁偃充偽吏嗓寐惺扮拱芫茜藉虢鈔偈偉晶偌宕距析濾殿疼癱註頗偓偕鴨歇滯偝偟偢忘怡旺偨偩偪偫偭偯偰偱偲偵緝蹄偷減惰漏窺竊偸偺迹傀儡傅傈僳傌籬傎奎琳迪叟芭傒傔傕傖悉荒傜傞傢傣芽逼傭婢傮睨寄檄誦謠頌傴擔辜弓慘蒿悼疤傺傻屄臆巢洩篋羨蓋軋頹傿儸僄僇僉僊働僎僑僔僖僚僝僞僣僤僥僦猴僨僩僬僭僮僯僰僱僵殖籤靜僾僿征隴儁儂儃儇儈朴薄儊儋儌儍儐儓儔儕儗儘儜儞儤儦儩汰哉寡渥裕酷儭儱罐儳儵儹儺儼儽兀臬臲鷲允勛勳宙宵帥憝彞諧嫂鬩暢沛溢盈飢赫兇悍狠猛頑愚妣斬秦遣鞭耀敏榮槃澤爆碟磁禿纜輝霽鹵朵婁孜烽醬勃汀箕裘鉗耶懞蕾徹兌軟遭黜兎児韻媳爸兕觥兗兙兛兜售鍪肚兝兞兟兡兢兣樽殮涅睡稟籍贅泌啡肽奸幕涵澇熵疚眷稃襯訌赴煥椒殲植跏沒試誤猜棲窗肋袖頰兪卦撇鬍岐廓轎疸楓茴瓏廁秩募勺噸寓斤曆畝迫筷釐最淫螺韜兮寬匪篩襄贏軛複兲詐刃堰戎痞蟻餉它冀鑄冂冃円冇冉冊嫁厲礪竭醮冏牧冑冓冔冕冖冗冘冞冢窄抑誣冥冫烘菇蟄冷凝坨橇淇淋炭餅磚磧窖醋雕雹霜冱冶爐艷嘲峻灘淡漠煖颼飲冼冽凃凄愴梗凅凇凈凊凋敝濛凔凜遵汞脢凞几凢処凰凱凵凶焰凸摺刷紋預喪嘍奔巡榜殯芙蓉租籠輯鞘萃凼鋸鑊刁蠻刂娩崩批拆攤掰櫱驟歧顆秒袂贓勿囑忌磋琢膚刈羽刎訟戮舂槳艇刓刖霹靂刜創犢刡恙墅幟筵緻刦刧刨昏默攸尿慾薰潤薰圭刪刮痧鏟刱刲刳刴刵踏磅戳柏槐繡芹莧蝟舟銘鵠鶩刼剁剃辮剄剉履鉛剋剌姻咽哨廊掠桅沿召瞻翅趙卜渺茫郭剒剔剕瀝剚愎毅訥纔剜剝啄採剞剟剡剣剤綵剮腎駛黏剰袍剴紊剷剸剺剽剿劁劂劄劈啪柴扳啦劉奭姥夼昫涓熙禪禹錫翔雁鶚劊劌弩柄蜻蛉劒劓劖劘劙瀾簣賞磯釜晉甜薪逐劦熔紂虐赤囚劬劭労劵効劻劼劾峭艮勅勇勵勍勐臘脖龐漫飼盪粥輒勖勗勘驕餒碌泮雇捐竹騎殊阱勣樸懇謹勦勧勩勯勰勱勲勷勸懲慰誡諫勹芡踐闌匁庇拯粟紮袱裹餃匆遽匈匉匊匋匍匐莖匏匕妝痰膿蛹齋苑烤蹈塘羌熊閥螳螂疆碚竿緯荷茵邙魏匚匜匝匟扶稷匣匭攏匸匹耦匽匾匿卂叮瘡禧軫堤棚迢鈞鍊卄卆遐卉瓷盲瓶噹胱腱裸卋卌卍卐怯污賤鄙齷齪陋卓溪唐梯漁陳棗泥漳潯澗梨芬譙贍轅迦鄭単驢弈洽鰲卛占筮卝卞卟吩啉屎翠厄卣卨卪卬卮榫襖璽綬鈕蚤懼殆篤聳卲帘帙繞卹卼卽厂厎厓厔厖厗奚厘厙厜厝諒厠厤厥厪膩孢厮厰厳厴厹厺粕垢蕪菁厼厾叁悟茸薯叄吵笄悌哺譏
坫壟弧芯杠潛嬰芻袁詰貪諜煽饋駁収岳締災賄騙叚叡吻攔蘑蜜訣燧玩硯箏椎藺銅逗驪另覓叨嘮謁杵姓喊嚷囂咚嚀塑尋惱憎擦祇泣滲蝠叱吒咄咤喝籀黛舵舷叵叶鐸懿昭穰苴遼叻叼吁塹嫖賭瞧爬衆抒吅吆夥巹橡滌抱縱摩郡唁墜扇籃膀襪頸吋愾諮酬哭妓媛暗錶韁邇妃羿絮蕃渾拐葵暮隅吔吖啶嗪戚吜嗇噬嚥吟哦詠吠吧唧嗒咐吪雋咀徵燐苞茹鈣哧吮吰吱嘎吲哚吳棟嬌窟孟簫忠晗淞闔閭趼宇吶睛噓拂捧疵熄竽笛糠吼吽呀呂韋矇呃呆笨呇貢呉罄呋喃呎呏呔呠呡癡呣呤呦呧瑛眩扒晬淑姬瑜璇鵑呪呫嗶嚅囁呬呯呰呱呲咧噌鈍呴呶呷呸呺呻哱咻嘯嚕籲坎坷邏呿咁咂咆哮咇咈咋蟹煦珅藹咍咑咒詛咔噠嚓咾噥哩喱咗咠咡咢咣咥咦咨嗟詢咩咪咫嚙齧咭咮咱咲咳嗆嗽咴咷咸咹咺咼喉咿婉慟憫賦矜綠茗藍哂搶瞞哆嗦囉噻啾濱彗哋哌哎唷喲哏哐哞哢哤哪裏哫啼喘哰哲萎蚌哳哶哽哿唄唅唆唈唉唎唏嘩堯棣殤璜睿肅唔睇唕唚唞唣喳唪唬唰喏唲唳唵嘛唶唸唹唻唼唾唿啁啃鸚鵡啅埠棧榷祺舖鞅飆啊啍啎啐啓啕啖啗啜啞祈啢啣啤啥啫啱啲啵啺饑啽噶崑沁喁喂喆裙喈嚨喋喌喎喑喒喓喔粗喙幛慶滋鵲喟喣喤喥喦喧騷喨喩梆喫葡萄喭駝挑嚇碰樅瓣純皰藻趟鉻喵営喹喺喼喿嗀嗃嗄嗅嗈嗉嗊嗍嗐嗑嗔詬嗕嗖嗙嗛嗜痂癖嗝嗡嗤嗥嗨嗩嗬嗯嗰嗲嗵嘰嗷嗹嗾嗿嘀嘁嘂嘅惋嘈峪禾蔭嘊嘌嘏嘐嘒嘓嘖嘚嘜嘞嘟囔嘣嘥嘦嘧嘬嘭這謔嚴敞饞鬆嘵嘶嘷嘸蝦嘹嘻嘽嘿噀噂噅噇噉噎噏噔噗噘噙噚噝噞噢噤蟬皿噩噫噭噯噱噲噳嚏涌灑欲巫霏噷噼嚃嚄嚆抖嚌嚐嚔囌嚚嚜嚞嚟嚦嚬嚭嚮嚯嚲嚳飭按竣苛嚵嚶囀囅囈膪謙囍囒囓囗囘蕭酚飄濺諦囝溯眸紇鑾鶻囟殉囡団囤囥囧囨囪囫圇囬囮囯囲図囶囷囸囹圄圉擬囻囿圀圂圃圊粹蠹赦圌墾圏滾鯡鑿枘圕圛圜圞坯埂壤骸炕祠窯豚紳魠鯪鱉圧握圩圪垯圬圮圯炸岬幔毯祇窨菩溉圳圴圻圾坂坆沾坋坌舛壈昆墊墩椅坒坓坩堝坭坰坱坳坴坵坻坼楊掙涎簾垃垈垌垍垓垔垕垗垚垛垝垣垞垟垤垧垮垵垺垾垿埀畔埄埆埇埈埌殃隍埏埒埕埗埜埡埤埦埧埭埯埰埲埳埴埵埶紼埸培怖樁礎輔埼埽堀訶姪廡堃堄摧磐貞韌砌堈堉堊堋堌堍堎堖堙堞堠礁堧堨輿堭堮蜓摘堲堳堽堿塁塄塈煤塋棵塍塏塒塓綢���鴉沽虱塙塚塝繆塡塢塤塥塩塬塱塲蟎塼塽塾塿墀墁墈墉墐夯増毀墝墠墦漬缽墫墬墮墰墺墻櫥壅壆壊壌壎壒榨蒜壔壕壖壙壚壜壝壠壡壬壭壱売壴壹壻壼寢壿夂夅夆変夊夌漱邑夓腕泄甥禦骼夗夘夙袞瑙妊娠醣梟珊鶯鷺戧幻魘夤蹀祕擂鶇姚宛閨嶼庾撻拇賛蛤裨菠氅漓撈湄蚊霆鯊箐篆篷荊肆舅荔鮃巷慚骰辟邱鎔鐮阪漂燴鯢鰈鱷鴇臚鵬妒峨譚枰晏璣癸祝秤竺牡籟恢罡螻蠍賜絨御梭夬夭砣榆怙枕夶夾餡奄崛葩譎奈賀祀贈奌奐奓奕訢詝奘奜奠奡奣陶奨奩魁奫奬奰媧孩貶隸酥宄狡猾她奼嫣妁氈荼皋膻蠅嬪妄妍嫉媚嬈妗趣妚妞妤礙妬婭妯娌妲妳妵妺姁姅姉姍姒姘姙姜姝姞姣姤姧姫姮娥姱姸姺姽婀娀誘懾脅娉婷娑娓娟娣娭娯娵娶娸娼婊婐婕婞婤婥谿孺婧婪婬婹婺婼婽媁媄媊媕媞媟媠媢媬媮媯媲媵媸媺媻媼眯媿嫄嫈嫋嫏嫕嫗嫘嫚嫜嫠嫡嫦嫩嫪毐嫫嫬嫰嫵嫺嫻嫽嫿嬀嬃嬅嬉耍嬋痴豔嬔嬖嬗嬙嬝嬡嬢嬤嬦嬬嬭幼嬲嬴嬸嬹嬾嬿孀孃孅孌孏曰癲屏孑孓雀孖斟簍謎摺孛矻鳩崮軻祜鸞孥邈毓棠臏孬孭孰孱孳孵泛罔銜孻孿宀宁宂拙株薇掣撫琪瓿榴謐彌宊濂祁瑕宍宏碁宓邸讞実潢町宥宧宨宬徵崎駿掖闕臊煮禽蠶宸豫寀寁寥寃簷庶寎暄磣寔寖寘寙寛寠苫寤肘洱濫蒗陝覈寪弘綽螽寳擅疙瘩晷対檐専尃尅贖絀繚疇釁尌峙醌襟痲碧屁昊槌淘恵瀑牝畑莓缸羚覷蔻髒躁尒尓銳尗尙尜尟尢尥尨尪尬尭尰擒尲尶尷尸尹潽蠖蛾尻釦梢蚴鰭脬蹲屇屌蚵屐屓挪屖屘屙屛屝屢屣巒嶂巖舄屧屨屩屪屭屮戍駐鉀崖嵛巔旮旯楂欖櫸芋茱萸靛麓屴屹屺屼岀岊岌岍阜岑彭鞏岒岝岢嵐岣岧岨岫岱岵岷峁峇峋峒峓峞峠嵋峩峯峱峴峹峿崀崁崆禎崋崌崍嶇崐崒崔嵬巍螢顥崚崞崟崠崢巆崤崦崧殂崬崱崳崴崶崿嵂嵇嵊泗嵌嵎嵒嵓嵗嵙嵞嵡嵩嵫嵯嵴嵼嵾嶁嶃嶄晴嶋嶌嶒嶓嶔嶗嶙嶝嶞嶠嶡嶢嶧嶨嶭嶮嶰嶲嶴嶸巂巃巇巉巋巌巓巘巛滇芎巟巠弋迴巣巤炊擘蜥蟒蠱覡巰蜀彥淖杏茂甫楞巻巽幗巿帛斐鯽蕊帑帔帗帚琉汶帟帡帣帨帬帯帰帷帹暆幃幄幇幋幌幏幘幙幚幞幠幡幢幦幨幩幪幬幭幯幰遙蹉跎餘庚鑑幵幷稚邃庀庁広庄庈庉笠庋跋庖犧庠庤庥鯨庬庱庳庴庵馨衢庹庿廃廄廆廋廌廎廏廐廑廒廕廖廛廝搏鑼廞弛袤廥廧廨廩廱綿踵髓廸廹甌鄴廻廼廾廿躔弁皺弇弌弍弎弐弒弔詭憾薦弝弢弣弤弨弭弮弰弳霖繇燾斌旭溥騫弶弸弼弾彀彄彆纍糾彊彔彖彘彟彠陌彤貽彧繪虹彪炳彫蔚鷗彰癉彲彳彴彷彷徉徨彸彽踩斂旆徂徇徊渭畬鉉裼従筌徘徙徜徠膳甦萌漸徬徭醺徯徳徴潘徻徼忀瘁胖燎怦悸顫扉犀澎湃砰恍惚絞隘忉憚挨餓忐忑忒忖応忝忞耿忡忪忭忮忱忸怩忻悠懣怏遏怔怗怚怛怞懟黍訝怫怭懦怱怲怳怵惕怸怹恁恂恇恉恌恏恒恓恔恘恚恛恝恞恟恠恣恧眄恪恫恬澹恰恿悀悁悃悄悆悊悐悒晦悚悛悜悝悤您悩悪悮悰悱悽惻悳悴悵惘悶悻悾惄愫鍾蒐惆惇惌惎惏惓惔惙惛耄惝瘧濁惥惦惪惲惴惷惸拈愀愃愆愈愊愍愐愑愒愓愔愕愙氓蠢騃昵愜赧愨愬愮愯愷愼慁慂慅慆慇靄慉慊慍慝慥慪慫慬慱慳慴慵慷慼焚憀灼鬱憃憊憋憍眺捏軾憒憔憖憙憧憬憨憪憭憮憯憷憸憹憺懃懅懆邀懊懋懌懍懐懞懠懤懥懨懫懮懰懱毖懵遁樑雍懺懽戁戄戇戉戔戕戛戝戞戠戡戢戣戤戥戦戩戭戯轟戱披菊牖戸戹戺戻戼戽鍬扂楔扃扆扈扊杖牽絹銬鐲賚扐摟攪烊盹瞌跟躉鑔靶鼾払扗玫腮扛扞扠扡扢盔押扤扦扱罾揄綏鞍郤窾扻扼扽抃抆抈抉抌抏瞎抔繯縊擻抜抝択抨摔歉躥牾抶抻搐泵菸拃拄拊髀拋拌脯拎拏拑擢秧沓曳攣迂拚拝拠拡拫拭拮踢拴拶拷攢拽掇芥橐簪摹疔挈瓢驥捺蹻挌挍挎挐揀挓挖掘浚挙揍聵挲挶挾挿捂捃捄捅捆捉捋胳膊揎捌捍捎軀蛛捗捘捙捜捥捩捫捭据捱捻捼捽掀掂掄臀膘掊掎掏掐笙掔掗掞棉芍掤搪闡掫掮掯揉掱掲掽掾揃揅揆搓揌諢揕揗揘揜揝揞揠揥揩揪揫櫫遒麈揰揲揵揶揸揹揺搆搉搊搋搌搎搔搕撼櫓搗搘搠搡搢搣搤搥搦搧搨搬楦褳訕赸搯搰搲搳搴搵搷搽搾搿摀摁摂摃摎摑摒摓跤摙摛摜摞摠摦睺羯摭摮摯摰摲摳摴摶摷摻摽撂撃撅稻撊撋撏鐧潑撕撙撚撝撟撢撣撦撧撩撬撱朔撳蚍蜉撾撿擀擄闖擉缶觚擐擕擖擗擡擣擤澡腚擧擨擩擫擭擯擰擷擸擼擽擿攃攄攆攉攥攐攓攖攙攛每攩攫轡澄攮攰攲攴軼攷砭訐攽碘敁敃敇敉敍敎筏敔敕敖閏誨敜煌敧敪敱敹敺敻敿斁衽斄牒縐謅斉斎斕鶉讕駮鱧斒筲斛斝斞斠斡斢斨斫斮晾沂潟穎絳邵斲斸釳於琅斾斿旀旂旃旄渦旌旎旐旒旓旖旛旝旟旡旣浴旰獺魃旴旹旻旼旽昀昃昄昇昉晰躲澈熹皎皓礬昑昕昜昝昞昡昤暉筍昦昨昰昱昳昴昶昺昻晁蹇隧蔬髦晄晅晒晛晜晞晟晡晢晤晥曦晩萘瑩顗晿暁暋暌暍暐暔暕煅暘暝暠暡曚暦暨暪朦朧暱暲殄馮暵暸暹暻暾曀曄曇曈曌曏曐曖曘曙曛曡曨曩駱曱甴肱曷牘禺錕曽滄耽朁朅朆杪栓誇竟粘絛朊膺朏朐朓朕朘朙瞄覲溘饔飧朠朢朣柵椆澱蝨朩朮朰朱炆璋鈺熾鹮朳槿朶朾朿杅杇杌隉欣釗湛漼楷瀍煜玟纓翱肈舜贄适逵杓杕杗杙荀蘅杝杞脩珓筊杰榔狍閦顰緬莞杲杳眇杴杶杸杻杼枋枌枒枓衾葄翹紓逋枙狸椏枟槁枲枳枴枵枷枸櫞枹枻柁柂柃柅柈柊柎某柑橘柒柘柙柚柜柞櫟柟柢柣柤柩柬柮柰柲橙柶柷柸柺査柿栃栄栒栔栘栝栟栢栩栫栭栱栲栳栴檀栵栻桀驁桁鎂桄桉桋桎梏椹葚桓桔桕桜桟桫欏桭桮桯桲桴桷桹湘溟梃梊梍梐潼梔梘梜梠梡梣梧梩梱梲梳梴梵梹棁棃櫻棐棑棕櫚簑繃蓑棖棘棜棨棩棪棫棬棯棰棱棳棸棹槨棼椀椄苕椈椊椋椌椐椑椓椗検椤椪椰椳椴椵椷椸椽椿楀楄楅篪楋楍楎楗楘楙楛楝楟楠楢楥楨楩楪楫楬楮楯楰楳楸楹楻楽榀榃榊榎槺榕榖榘榛狉莽榜笞榠榡榤榥榦榧榪榭榰榱槤霰榼榾榿槊閂槎槑槔槖様槜槢槥槧槪槭槮槱槲槻槼槾樆樊樏樑樕樗樘樛樟樠樧樨権樲樴樵猢猻樺樻罍樾樿橁橄橆橈笥龠橕橚橛輛橢橤橧豎膈跨橾橿檁檃檇檉檍檎檑檖檗檜檟檠檣檨檫檬檮檳檴檵檸櫂櫆櫌櫛櫜櫝櫡櫧櫨櫪櫬櫳櫹櫺茄櫽欀欂欃欐欑欒欙欞溴欨欬欱欵欶欷歔欸欹欻欼欿歁歃歆艎歈歊蒔蝶歓歕歘歙歛歜歟歠蹦詮鑲蹣跚陞陟歩歮歯歰歳歴璞歺瞑歾歿殀殈殍殑殗殜殙殛殞殢殣殥殪殫殭殰殳荃殷殸殹蛟殻殽謗毆毈毉餵毎毑蕈毗毘毚茛鄧毧毬毳毷毹毽毾毿氂氄氆靴氉氊氌氍氐聊氕氖気氘氙氚氛氜氝氡洶焊痙氤氳氥氦鋁鋅氪烴氬銨痤汪滸漉痘盂碾菖蒲蕹蛭螅氵氷氹氺氽燙氾氿渚汆汊汋汍汎汏汐汔汕褟汙汚汜蘺沼穢衊汧汨汩汭汲汳汴隄汾沄沅沆瀣沇沈葆浸淪湎溺痼痾沌沍沏沐沔沕沘浜畹礫沚沢沬沭沮沰沱灢沴沷籽沺烹濡洄泂肛泅泆湧肓泐泑泒泓泔泖泙泚泜泝泠漩饃濤粼濘蘚鰍泩泫泭泯銖泱泲洇洊涇琵琶荽薊箔洌洎洏洑潄濯洙洚洟洢洣洧洨洩痢滔洫洮洳洴洵洸洹洺洼洿淌蜚浄浉浙贛渫浠浡浤浥淼瀚浬浭翩萍浯浰蜃淀苔蛞蝓蜇螵蛸煲鯉浹浼浽溦涂涊涐涑涒涔滂涖涘涙涪涫涬涮涴涶涷涿淄淅淆淊淒黯淓淙漣淜淝淟淠淢淤淥淦淩猥藿褻淬淮淯淰淳詣淶紡淸淹燉癯綺渇済渉渋渓渕渙渟渢滓渤澥渧渨渮渰渲渶渼湅湉湋湍湑湓湔黔湜湝湞湟湢湣湩湫湮麟湱湲湴湼満溈溍溎溏溛舐漭溠溤溧馴溮溱溲溳溵溷溻溼溽溾滁滃滉滊滎滏稽滕滘滙滝滫滮羼耷滷滹滻煎漈漊漎繹漕漖漘漙漚漜漪漾漥漦漯漰漵漶漷濞潀潁潎潏潕潗潚潝潞潠潦祉瘍潲潵潷潸潺潾潿澁澂澃澉澌澍澐澒澔澙澠澣澦澧澨澫澬澮澰澴澶澼熏郁濆濇濈濉濊貊濔疣濜濠濩觴濬濮盥濰濲濼瀁瀅瀆瀋瀌瀏瀒瀔瀕瀘瀛瀟瀠瀡瀦瀧瀨瀬瀰瀲瀳瀵瀹瀺瀼灃灄灉灋灒灕灖灝灞灠灤灥灨灩灪蜴灮燼獴灴灸灺炁炅魷炗炘炙炤炫疽烙釺炯炰炱炲炴炷燬炻烀烋瘴鯧烓烔焙烜烝烳飪烺
焃焄耆焌焐焓焗焜焞焠焢焮焯焱焼煁煃煆煇煊熠煍熬煐煒煕煗燻礆霾煚煝煟煠煢矸煨瑣煬萁煳煺煻熀熅熇熉羆熒穹熗熘熛熜稔諳爍熤熨熯熰眶螞熲熳熸熿燀燁燂燄盞燊燋燏燔隼燖燜燠燡燦燨燮燹燻燽燿爇爊爓爚爝爟爨蟾爯爰爲爻爿爿牀牁牂牄牋牎牏牓牕釉牚腩蒡虻牠雖蠣牣牤牮牯牲牳牴牷牸牼絆牿靬犂犄犆犇犉犍犎犒犖犗犛犟犠犨犩犪犮犰狳犴犵犺狁甩狃狆狎狒獾狘狙黠狨狩狫狴狷狺狻豕狽蜘猁猇猈猊猋猓猖獗猗猘猙獰獁猞猟獕猭猱猲猳猷猸猹猺玃獀獃獉獍獏獐獒獘獙獚獜獝獞獠獢獣獧鼇蹊獪獫獬豸獮獯鬻獳獷獼玀玁菟玅玆玈珉糝禛郅玍玎玓瓅玔玕玖玗玘玞玠玡玢玤玥玦玨瑰玭玳瑁玶玷玹玼珂珇珈瑚珌饈饌珔珖珙珛珞珡珣珥珧珩珪珮珶珷珺珽琀琁隕琊琇琖琚琠琤琦琨琫琬琭琮琯琰琱琲瑯琹琺琿瑀瑂瑄瑉瑋瑑瑔瑗瑢瑭瑱瑲瑳瑽瑾瑿璀璨璁璅璆璈璉璊璐璘璚璝璟璠璡璥璦璩璪璫璯璲璵璸璺璿瓀瓔瓖瓘瓚瓛臍瓞瓠瓤瓧瓩瓮瓰瓱瓴瓸瓻瓼甀甁甃甄甇甋甍甎甏甑甒甓甔甕甖甗飴蔗甙詫鉅粱盎銹糰甡褥産甪甬甭甮甯鎧甹甽甾甿畀畁畇畈畊畋畎畓畚畛畟鄂畤畦畧荻畯畳畵畷畸畽畾疃疉疋疍疎簞疐疒疕疘疝疢疥疧疳疶疿痁痄痊痌痍痏痐痒痔痗瘢痚痠痡痣痦痩痭痯痱痳痵痻痿瘀瘂瘃瘈瘉瘊瘌瘏瘐瘓瘕瘖瘙瘚瘛瘲瘜瘝瘞瘠瘥瘨瘭瘮瘯瘰癧瘳癘瘵瘸瘺瘻瘼癃癆癇癈癎癐癔癙癜癠癤癥癩蟆癪癭癰発踔紺蔫酵皙砬砒翎翳蘞鎢鑞皚鵯駒鱀粵褶皀皁莢皃鎛皈皌皐皒硃皕皖皘皜皝皞皤皦皨皪皫皭糙綻皴皸皻皽盅盋盌盍盚盝踞盦盩鞦韆盬盭眦睜瞤盯盱眙裰盵盻睞眂眅眈眊県眑眕眚眛眞眢眣眭眳眴眵眹瞓眽郛睃睅睆睊睍睎睏睒睖睙睟睠睢睥睪睪睯睽睾瞇瞈瞋瞍逛瞏瞕瞖瞘瞜瞟瞠瞢瞫瞭瞳瞵瞷瞹瞽闍瞿矓矉矍鑠矔矗矙矚矞矟矠矣矧矬矯矰矱硪碇磙��舫阡、矼矽礓砃砅砆砉砍砑砕砝砟砠砢砦砧砩砫砮砳艏砵砹砼硇硌硍硎硏硐硒硜硤硨磲茚鋇硭硻硾碃碉碏碣碓碔碞碡碪碫碬碭碯碲碸碻礡磈磉磎磑磔磕磖磛磟磠磡磤磥蹭磪磬磴磵磹磻磽礀礄礅礌礐礚礜礞礤礧礮礱礲礵礽礿祂祄祅祆禳祊祍祏祓祔祕祗祘祛祧祫祲祻祼餌臠錮禂禇禋禑禔禕隋禖禘禚禜禝禠禡禢禤禥禨禫禰禴禸稈秈秊闈颯秌秏秕笈蘵賃秠秣秪秫秬秭秷秸稊稌稍稑稗稙稛稞稬稭稲稹稼顙稾穂穄穇穈穉穋穌貯穏穜穟穠穡穣穤穧穨穭穮穵穸窿闃窀窂窅窆窈窕窊窋窌窒窓窔窞窣窬黷蹙窰窳窴窵窶窸窻竁竃竈竑竜竝竦竪篦篾笆鮫竾笉笊笎笏笐靨笓笤籙笪笫笭笮笰笱笲笳笵笸笻筀筅筇筈筎筑筘筠筤筥筦筧筩筭筯筰筱筳筴讌筸箂箇箊箎箑箒箘箙箛箜篌箝箠箬鏃箯箴箾篁篔簹篘篙篚篛篜篝篟篠篡篢篥篧篨篭篰篲篳篴篶篹篼簀簁簃簆簉簋簌簏簜簟簠簥簦簨簬簰簸簻籊籐籒籓籔籖籚籛籜籣籥籧籩籪籫籯芾麴籵籸籹籼粁粃粋粑粔糲粛粞粢粧粨粲粳粺粻粽闢粿糅糆糈糌糍糒糔萼糗蛆蹋糢糨糬糭糯糱糴糶糸糺紃蹼鰹黴紆紈絝紉閩襻紑紕紘錠鳶鷂紝紞紟紥紩紬紱紲紵紽紾紿絁絃絅経絍絎絏縭褵絓絖絘絜絢絣螯絪絫聒絰絵絶絺絻絿綀綃綅綆綈綉綌綍綎綑綖綘継続緞綣綦綪綫綮綯綰罟蝽綷縩綹綾緁緄緅緆緇緋緌緎総緑緔緖緗緘緙緜緡緤緥緦纂緪緰緱緲緶緹縁縃縄縈縉縋縏縑縕縗縚縝縞縟縠縡縢縦縧縯縰騁縲縳縴縵縶縹縻衙縿繄繅繈繊繋繐繒繖繘繙繠繢繣繨繮繰繸繻繾纁纆纇纈纉纊纑纕纘纙纚纛缾罃罆罈罋罌罎罏罖罘罛罝罠罣罥罦罨罫罭鍰罳罶罹罻罽罿羂羃羇羋蕉51鴕羑羖羗羜羝羢羣羥羧羭羮羰羱羵羶羸藜鮐翀翃翄翊翌翏翕翛翟翡翣翥翦躚翪翫翬翮翯翺翽翾翿闆饕鴰鍁耋耇耎耏耑耒耜耔耞耡耤耨耩耪耬耰鬢耵聹聃聆聎聝聡聦聱聴聶聼閾聿肄肏肐肕腋肙肜肟肧胛肫肬肭肰肴肵肸肼胊胍胏胑胔胗胙胝胠銓胤胦胩胬胭胯胰胲胴胹胻胼胾脇脘脝脞脡脣脤脥脧脰脲脳腆腊腌臢腍腒腓腖腜腠腡腥腧腬腯踝蹬鐐腴腶蠕誹膂膃膆膇膋膔膕膗膙膟黐膣膦膫膰膴膵膷膾臃臄臇臈臌臐臑臓臕臖臙臛臝臞臧蓐詡臽臾臿舀舁鰟鮍舋舎舔舗舘舝舠舡舢舨舭舲舳舴舸舺艁艄艅艉艋艑艕艖艗艘艚艜艟艣艤艨艩艫艬艭荏艴艶艸艹艻艿芃芄芊萰陂藭芏芔芘芚蕙芟芣芤茉芧芨芩芪芮芰鰱芴芷芸蕘豢芼芿苄苒苘苙苜蓿苠苡苣蕒苤苧苪鎊苶苹苺苻苾茀茁范蠡萣茆茇茈茌茍茖茞茠茢茥茦菰茭茯茳藨茷藘茼荁荄荅荇荈菅蜢鴞荍荑荘荳荵荸薺莆莒莔莕莘莙莚莛莜莝莦莨菪莩莪莭莰莿菀菆菉菎菏菐菑菓菔菕菘菝菡菢菣菥蓂菧菫轂鎣菶菷菹醢菺菻菼菾萅萆萇萋萏萐萑萜萩萱萴萵萹萻葇葍葎葑葒葖葙葠葥葦葧葭葯葳葴葶葸葹葽蒄蒎蒓蘢薹蒞蒟蒻蒢蒦蒨蒭藁蒯蒱鉾蒴蒹蒺蒽蓀蓁蓆蓇蓊蓌蓍蓏蓓蓖蓧蓪蓫蓽跣藕蓯蓰蓱蓴蓷蓺蓼蔀蔂蔃蔆蔇蔉蔊蔋蔌蔎蔕蔘蔙蔞蔟鍔蔣雯蔦蔯蔳蔴蔵蔸蔾蕁蕆蕋蕍蕎蕐蕑蕓蕕蕖蕗蕝蕞蕠蕡蕢蕣蕤蕨蕳蕷蕸蕺蕻薀薁薃薅薆薈薉薌薏薐薔薖薘薙諤釵薜薠薢薤薧薨薫薬薳薶薷薸薽薾薿藄藇藋藎藐藙藚藟藦藳藴藶藷藾蘀蘁蘄蘋蘗蘘蘝蘤蘧蘩蘸蘼虀虆虍蟠虒虓虖虡虣虥虩虯虰蛵虵虷鱒虺虼蚆蚈蚋蚓蚔蚖蚘蚜蚡蚣蚧蚨蚩蚪蚯蚰蜒蚱蚳蚶蚹蚺蚻蚿蛀蛁蛄蛅蝮蛌蛍蛐蟮蛑蛓蛔蛘蛚蛜蛡蛣蜊蛩蛺蛻螫蜅蜆蜈蝣蜋蜍蜎蜑蠊蜛餞蜞蜣蜨蜩蜮蜱蜷蜺蜾蜿蝀蝃蝋蝌蝍蝎蝏蝗蝘蝙蝝鱝蝡蝤蝥蝯蝰蝱蝲蝴蝻螃蠏螄螉螋螒螓螗螘螙螚蟥螟螣螥螬螭螮螾螿蟀蟅蟈蟊蟋蟑蟓蟛蟜蟟蟢蟣蟨蟪蟭蟯蟳蟶蟷蟺蟿蠁蠂蠃蠆蠋蠐蠓蠔蠗蠙蠚蠛蠜蠧蠨蠩蠭蠮蠰蠲蠵蠸蠼蠽衁衂衄衇衈衉衋衎衒衕衖衚衞裳鈎衭衲衵衹衺衿袈裟袗袚袟袢袪袮袲袴袷袺袼褙袽裀裉裊裋裌裍裎裒裛裯裱裲裴裾褀褂褉褊褌褎褐褒褓褔褕褘褚褡褢褦褧褪褫褭褯褰褱襠褸褽褾襁襃襆襇襉襋襌襏襚襛襜襝襞襡襢襤襦襫襬襭襮襴襶襼襽襾覂覃覅覇覉覊覌覗覘覚覜覥覦覧覩覬覯覰観覿觔觕觖觜觽觝觡酲觩觫觭觱觳觶觷觼觾觿言賅訃訇訏訑訒詁託訧訬訳訹証訾詀詅詆譭詈詊詎詑詒詖詗詘詧詨詵詶詸詹詻詼詿誂誃誄鋤誆誋誑誒誖誙誚誥誧説読誯誶誾諂諄諆諌諍諏諑諕諗諛諝諞諟諠諡諴諵諶諼謄謆謇謌謍謏謑謖謚謡謦謪謫謳謷謼謾譁譅譆譈譊譌譒譔譖鑫譞譟譩譫譬譱譲譴譸譹譾讅讆讋讌讎讐讒讖讙讜讟谽豁豉豇豈豊豋豌豏豔豞豖豗豜豝豣豦豨豭豱豳豵豶豷豺豻貅貆貍貎貔貘貙貜貤饜貰餸貺賁賂賏賒賕賙賝賡賧賨賫鬭賮賵賸賺賻賾贇贉贐贔贕贗赬赭赱赳迄趁趂趄趐趑趒趔趡趦趫趮趯趲趴趵趷趹趺趿跁跂跅跆躓蹌跐跕跖跗跙跛跦跧跩跫跬跮跱跲跴跺跼跽踅踆踈踉踊踒���踘踜踟躇躕踠踡踣踤踥踦踧蹺踫踮踰踱踴踶踹踺踼踽躞蹁蹂躪蹎蹐蹓蹔蹕蹚蹜蹝蹟蹠蹡蹢躂蹧蹩蹪蹯鞠蹽躃躄躅躊躋躐躑躒躘躙躛躝躠躡躦躧躩躭躰躳躶軃軆輥軏軔軘軜軝齶転軥軨軭軱軲轆軷軹軺軽軿輀輂輦輅輇輈輓輗輙輜輞輠輤輬輭輮輳輴輵輶輹輼輾轀轇轏轑轒轔轕轖轗轘轙轝轞轢轤辠辢辤辵辶辺込辿迅迋迍麿迓迣迤邐迥迨迮迸迺迻迿逄逅逌逍逑逓逕逖逡逭逯逴逶逹遄遅遉遘遛遝遢遨遫遯遰遴遶遹遻邂邅邉邋邎邕邗邘邛邠邢邧邨邯鄲邰邲邳邴邶邷邽邾邿郃郄郇郈郔郕郗郙郚郜郝郞郟郠郢郪郫郯郰郲郳郴郷郹郾郿鄀鄄鄆鄇鄈鄋鄍鄎鄏鄐鄑鄒鄔鄕鄖鄗鄘鄚鄜鄞鄠鄢鄣鄤鄦鄩鄫鄬鄮鄯鄱鄶鄷鄹鄺鄻鄾鄿酃酅酆酇酈酊酋酎酏酐酣酔酕醄酖酗酞酡酢酤酩酴酹酺醁醅醆醊醍醐醑醓醖醝醞醡醤醨醪醭醯醰醱醲醴醵醸醹醼醽醾釂釃釅釆釈鱸鎦閶釓釔釕鈀釙鼢鼴釤釧釪釬釭釱釷釸釹鈁鈃鈄鈆鈇鈈鈊鈌鈐鈑鈒鈤鈥鈧鈬鈮鈰鈳鐺鈸鈹鈽鈿鉄鉆鉈鉋鉌鉍鉏鉑鉕鉚鉢鉥鉦鉨鉬鉭鉱鉲鉶鉸鉺鉼鉿銍銎銑銕鏤銚銛銠銣銤銥銦銧銩銪銫銭銰銲銶銻銼銾鋂鋃鋆鋈鋊鋌鋍鋏鋐鋑鋕鋘鋙鋝鋟鋦鋨鋩鋭鋮鋯鋰鋱鋳鋹鋺鋻鏰鐱錀錁錆錇錈錍錏錒錔錙錚錛錞錟錡錤錩錬録錸錼鍀鍆鍇鍉鍍鍏鍐鍘鍚鍛鍠鍤鍥鍩鍫鍭鍱鍴鍶鍹鍺鍼鍾鎄鎇鎉鎋鎌鎍鎏鎒鎓鎗鎘鎚鎞鎡鎤鎩鎪鎭鎯鎰鎳鎴鎵鎸鎹鎿鏇鏊鏌鏐鏑鏖鏗鏘鏚鏜鏝鏞鏠鏦鏨鏷鏸鏹鏻鏽鏾鐃鐄鐇鐏鐒鐓鐔鐗馗鐙鐝鐠鐡鐦鐨鐩鐫鐬鐱鐳鐶鐻鐽鐿鑀鑅鑌鑐鑕鑚鑛鑢鑤鑥鑪鑭鑯鑱鑴鑵鑷钁钃镻閆閈閌閎閒閔閗閟閡関閤閤閧閬閲閹閺閻閼閽閿闇闉闋闐闑闒闓闘闚闞闟闠闤闥阞阢阤阨阬阯阹阼阽陁陑陔陛陜陡陥陬騭陴険陼陾隂隃隈隒隗隞隠隣隤隩隮隰顴隳隷隹雂雈雉雊雎雑雒雗雘雚雝雟雩雰雱驛霂霅霈霊霑霒霓霙霝霢霣霤霨霩霪霫霮靁靆靉靑靚靣靦靪靮靰靳靷靸靺靼靿鞀鞃鞄鞌鞗鞙鞚鞝鞞鞡鞣鞨鞫鞬鞮鞶鞹鞾韃韅韉馱韍韎韔韖韘韝韞韡韣韭韮韱韹韺頀颳頄頇頊頍頎頏頒頖頞頠頫頬顱頯頲頴頼顇顋顑顒顓顔顕顚顜顢顣顬顳颭颮颱颶颸颺颻颽颾颿飀飂飈飌飜飡飣飤飥飩飫飮飱飶餀餂餄餎餇餈餑餔餕餖餗餚餛餜餟餠餤餧餩餪餫餬餮餱餲餳餺餻餼餽餿饁饅饇饉饊饍饎饐饘饟饢馘馥馝馡馣騮騾馵馹駃駄駅駆駉駋駑駓駔駗駘駙駜駡駢駪駬駰駴駸駹駽駾騂騄騅騆騉騋騍騏驎騑騒験騕騖騠騢騣騤騧驤騵騶騸騺驀驂驃驄驆驈驊驌驍驎驏驒驔驖驙驦驩驫骺鯁骫骭骯骱骴骶骷髏骾髁髂髄髆髈髐髑髕髖髙髝髞髟髡髣髧髪髫髭髯髲髳髹髺髽髾鬁鬃鬅鬈鬋鬎鬏鬐鬑鬒鬖鬗鬘鬙鬠鬣鬪鬫鬬鬮鬯鬰鬲鬵鬷魆魈魊魋魍魎魑魖鰾魛魟魣魦魨魬魴魵魸鮀鮁鮆鮌鮎鮑鮒鮓鮚鮞鮟鱇鮠鮦鮨鮪鮭鮶鮸鮿鯀鯄鯆鯇鯈鯔鯕鯖鯗鯙鯠鯤鯥鯫鯰鯷鯸鯿鰂鰆鶼鰉鰋鰐鰒鰕鰛鰜鰣鰤鰥鰦鰨鰩鰮鰳鰶鰷鱺鰼鰽鱀鱄鱅鱆鱈鱎鱐鱓鱔鱖鱘鱟鱠鱣鱨鱭鱮鱲鱵鱻鲅鳦鳧鳯鳲鳷鳻鴂鴃鴄鴆鴈鴎鴒鴔鴗鴛鴦鴝鵒鴟鴠鴢鴣鴥鴯鶓鴳鴴鴷鴽鵀鵁鵂鵓鵖鵙鵜鶘鵞鵟鵩鵪鵫鵵鵷鵻鵾鶂鶊鶏鶒鶖鶗鶡鶤鶦鶬鶱鶲鶵鶸鶹鶺鶿鷀鷁鷃鷄鷇鷈鷉鷊鷏鷓鷕鷖鷙鷞鷟鷥鷦鷯鷩鷫鷭鷳鷴鷽鷾鷿鸂鸇鸊鸏鸑鸒鸓鸕鸛鸜鸝鹸鹹鹺
麀麂麃麄麇麋麌麐麑麒麚麛麝麤麩麪麫麮麯麰麺麾黁黈黌黢黒黓黕黙黝黟黥黦黧黮黰黱黲黶黹黻黼黽黿鼂鼃鼅鼈鼉鼏鼐鼒鼕鼖鼙鼚鼛鼡鼩鼱鼪鼫鼯鼷鼽齁齆齇齈齉齌齎齏齔齕齗齙齚齜齞齟齬齠齢齣齧齩齮齯齰齱齵齾龎龑龒龔龖龘龝龡龢龤'
|
20 |
+
|
21 |
+
assert len(simplified_charcters) == len(traditional_characters)  # the two tables must align one-to-one
|
22 |
+
|
23 |
+
s2t_dict = {}
|
24 |
+
t2s_dict = {}
|
25 |
+
for i, item in enumerate(simplified_charcters):
|
26 |
+
s2t_dict[item] = traditional_characters[i]
|
27 |
+
t2s_dict[traditional_characters[i]] = item
|
28 |
+
|
29 |
+
|
30 |
+
def tranditional_to_simplified(text: str) -> str:
|
31 |
+
return "".join(
|
32 |
+
[t2s_dict[item] if item in t2s_dict else item for item in text])
|
33 |
+
|
34 |
+
|
35 |
+
def simplified_to_traditional(text: str) -> str:
|
36 |
+
return "".join(
|
37 |
+
[s2t_dict[item] if item in s2t_dict else item for item in text])
|
38 |
+
|
39 |
+
|
40 |
+
if __name__ == "__main__":
|
41 |
+
text = "一般是指存取一個應用程式啟動時始終顯示在網站或網頁瀏覽器中的一個或多個初始網頁等畫面存在的站點"
|
42 |
+
print(text)
|
43 |
+
text_simple = tranditional_to_simplified(text)
|
44 |
+
print(text_simple)
|
45 |
+
text_traditional = simplified_to_traditional(text_simple)
|
46 |
+
print(text_traditional)
|
SongBloom/g2p/cn_zh_g2p/zh_normalization/chronology.py
ADDED
@@ -0,0 +1,134 @@
1 |
+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import re
|
15 |
+
|
16 |
+
from .num import DIGITS
|
17 |
+
from .num import num2str
|
18 |
+
from .num import verbalize_cardinal
|
19 |
+
from .num import verbalize_digit
|
20 |
+
|
21 |
+
|
22 |
+
def _time_num2str(num_string: str) -> str:
|
23 |
+
"""A special case for verbalizing number in time."""
|
24 |
+
result = num2str(num_string.lstrip('0'))
|
25 |
+
if num_string.startswith('0'):
|
26 |
+
result = DIGITS['0'] + result
|
27 |
+
return result
|
28 |
+
|
29 |
+
|
30 |
+
# 时刻表达式
|
31 |
+
RE_TIME = re.compile(r'([0-1]?[0-9]|2[0-3])'
|
32 |
+
r':([0-5][0-9])'
|
33 |
+
r'(:([0-5][0-9]))?')
|
34 |
+
|
35 |
+
# 时间范围,如8:30-12:30
|
36 |
+
RE_TIME_RANGE = re.compile(r'([0-1]?[0-9]|2[0-3])'
|
37 |
+
r':([0-5][0-9])'
|
38 |
+
r'(:([0-5][0-9]))?'
|
39 |
+
r'(~|-)'
|
40 |
+
r'([0-1]?[0-9]|2[0-3])'
|
41 |
+
r':([0-5][0-9])'
|
42 |
+
r'(:([0-5][0-9]))?')
|
43 |
+
|
44 |
+
|
45 |
+
def replace_time(match) -> str:
|
46 |
+
"""
|
47 |
+
Args:
|
48 |
+
match (re.Match)
|
49 |
+
Returns:
|
50 |
+
str
|
51 |
+
"""
|
52 |
+
|
53 |
+
is_range = len(match.groups()) > 5
|
54 |
+
|
55 |
+
hour = match.group(1)
|
56 |
+
minute = match.group(2)
|
57 |
+
second = match.group(4)
|
58 |
+
|
59 |
+
if is_range:
|
60 |
+
hour_2 = match.group(6)
|
61 |
+
minute_2 = match.group(7)
|
62 |
+
second_2 = match.group(9)
|
63 |
+
|
64 |
+
result = f"{num2str(hour)}点"
|
65 |
+
if minute.lstrip('0'):
|
66 |
+
if int(minute) == 30:
|
67 |
+
result += "半"
|
68 |
+
else:
|
69 |
+
result += f"{_time_num2str(minute)}分"
|
70 |
+
if second and second.lstrip('0'):
|
71 |
+
result += f"{_time_num2str(second)}秒"
|
72 |
+
|
73 |
+
if is_range:
|
74 |
+
result += "至"
|
75 |
+
result += f"{num2str(hour_2)}点"
|
76 |
+
if minute_2.lstrip('0'):
|
77 |
+
if int(minute) == 30:
|
78 |
+
result += "半"
|
79 |
+
else:
|
80 |
+
result += f"{_time_num2str(minute_2)}分"
|
81 |
+
if second_2 and second_2.lstrip('0'):
|
82 |
+
result += f"{_time_num2str(second_2)}秒"
|
83 |
+
|
84 |
+
return result
|
85 |
+
|
86 |
+
|
87 |
+
RE_DATE = re.compile(r'(\d{4}|\d{2})年'
|
88 |
+
r'((0?[1-9]|1[0-2])月)?'
|
89 |
+
r'(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?')
|
90 |
+
|
91 |
+
|
92 |
+
def replace_date(match) -> str:
|
93 |
+
"""
|
94 |
+
Args:
|
95 |
+
match (re.Match)
|
96 |
+
Returns:
|
97 |
+
str
|
98 |
+
"""
|
99 |
+
year = match.group(1)
|
100 |
+
month = match.group(3)
|
101 |
+
day = match.group(5)
|
102 |
+
result = ""
|
103 |
+
if year:
|
104 |
+
result += f"{verbalize_digit(year)}年"
|
105 |
+
if month:
|
106 |
+
result += f"{verbalize_cardinal(month)}月"
|
107 |
+
if day:
|
108 |
+
result += f"{verbalize_cardinal(day)}{match.group(9)}"
|
109 |
+
return result
|
110 |
+
|
111 |
+
|
112 |
+
# 用 / 或者 - 分隔的 YY/MM/DD 或者 YY-MM-DD 日期
|
113 |
+
RE_DATE2 = re.compile(
|
114 |
+
r'(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])')
|
115 |
+
|
116 |
+
|
117 |
+
def replace_date2(match) -> str:
|
118 |
+
"""
|
119 |
+
Args:
|
120 |
+
match (re.Match)
|
121 |
+
Returns:
|
122 |
+
str
|
123 |
+
"""
|
124 |
+
year = match.group(1)
|
125 |
+
month = match.group(3)
|
126 |
+
day = match.group(4)
|
127 |
+
result = ""
|
128 |
+
if year:
|
129 |
+
result += f"{verbalize_digit(year)}年"
|
130 |
+
if month:
|
131 |
+
result += f"{verbalize_cardinal(month)}月"
|
132 |
+
if day:
|
133 |
+
result += f"{verbalize_cardinal(day)}日"
|
134 |
+
return result
|
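
Usage sketch (an editorial illustration, not shipped in this commit; it assumes the repository root is on PYTHONPATH so that SongBloom imports as a package):

from SongBloom.g2p.cn_zh_g2p.zh_normalization.chronology import (
    RE_TIME_RANGE, replace_time)

# 30 minutes is read as "半" (half past); the two ends of a range are joined with "至" (to)
print(RE_TIME_RANGE.sub(replace_time, "营业时间8:30-12:30"))
# expected: 营业时间八点半至十二点半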
SongBloom/g2p/cn_zh_g2p/zh_normalization/constants.py
ADDED
@@ -0,0 +1,62 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import string

from pypinyin.constants import SUPPORT_UCS4

# Full-width <-> half-width conversion
# Full-width -> half-width mapping for ASCII letters (num: 52)
F2H_ASCII_LETTERS = {
    ord(char) + 65248: ord(char)
    for char in string.ascii_letters
}

# Half-width -> full-width mapping for ASCII letters
H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()}

# Full-width -> half-width mapping for digits (num: 10)
F2H_DIGITS = {ord(char) + 65248: ord(char) for char in string.digits}
# Half-width -> full-width mapping for digits
H2F_DIGITS = {value: key for key, value in F2H_DIGITS.items()}

# Full-width -> half-width mapping for punctuation (num: 32)
F2H_PUNCTUATIONS = {ord(char) + 65248: ord(char) for char in string.punctuation}
# Half-width -> full-width mapping for punctuation
H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()}

# Space (num: 1)
F2H_SPACE = {'\u3000': ' '}
H2F_SPACE = {' ': '\u3000'}

# Runs of characters that are not "Chinese characters with pinyin";
# usable for non-standard-word (NSW) extraction
if SUPPORT_UCS4:
    RE_NSW = re.compile(r'(?:[^'
                        r'\u3007'  # 〇
                        r'\u3400-\u4dbf'  # CJK Extension A: [3400-4DBF]
                        r'\u4e00-\u9fff'  # CJK Unified: [4E00-9FFF]
                        r'\uf900-\ufaff'  # CJK Compatibility: [F900-FAFF]
                        r'\U00020000-\U0002A6DF'  # CJK Extension B: [20000-2A6DF]
                        r'\U0002A703-\U0002B73F'  # CJK Extension C: [2A700-2B73F]
                        r'\U0002B740-\U0002B81D'  # CJK Extension D: [2B740-2B81D]
                        r'\U0002F80A-\U0002FA1F'  # CJK Compatibility Supplement: [2F800-2FA1F]
                        r'])+')
else:
    RE_NSW = re.compile(  # pragma: no cover
        r'(?:[^'
        r'\u3007'  # 〇
        r'\u3400-\u4dbf'  # CJK Extension A: [3400-4DBF]
        r'\u4e00-\u9fff'  # CJK Unified: [4E00-9FFF]
        r'\uf900-\ufaff'  # CJK Compatibility: [F900-FAFF]
        r'])+')
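
A quick sketch of how these tables plug into str.translate (illustrative only; same import assumption as the sketch above):

from SongBloom.g2p.cn_zh_g2p.zh_normalization.constants import (
    F2H_ASCII_LETTERS, F2H_DIGITS, F2H_SPACE)

fullwidth = "ABC\u3000123"  # full-width letters, ideographic space, full-width digits
print(fullwidth.translate(F2H_ASCII_LETTERS).translate(F2H_DIGITS).translate(F2H_SPACE))
# expected: "ABC 123"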
SongBloom/g2p/cn_zh_g2p/zh_normalization/num.py
ADDED
@@ -0,0 +1,282 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Rules to verbalize numbers into Chinese characters.
https://zh.wikipedia.org/wiki/中文数字#現代中文
"""
import re
from collections import OrderedDict
from typing import List

DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')}
UNITS = OrderedDict({
    1: '十',
    2: '百',
    3: '千',
    4: '万',
    8: '亿',
})

COM_QUANTIFIERS = '(封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)'

# Fraction expressions
RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')


def replace_frac(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    sign = match.group(1)
    numerator = match.group(2)
    denominator = match.group(3)
    sign: str = "负" if sign else ""
    numerator: str = num2str(numerator)
    denominator: str = num2str(denominator)
    result = f"{sign}{denominator}分之{numerator}"
    return result


# Percentage expressions
RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%')


def replace_percentage(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    sign = match.group(1)
    percent = match.group(2)
    sign: str = "负" if sign else ""
    percent: str = num2str(percent)
    result = f"{sign}百分之{percent}"
    return result


# Integer expressions
# Negative integers, e.g. -10
RE_INTEGER = re.compile(r'(-)' r'(\d+)')


def replace_negative_num(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    sign = match.group(1)
    number = match.group(2)
    sign: str = "负" if sign else ""
    number: str = num2str(number)
    result = f"{sign}{number}"
    return result


# Serial numbers (unsigned integers), e.g. 00078
RE_DEFAULT_NUM = re.compile(r'\d{3}\d*')


def replace_default_num(match):
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    number = match.group(0)
    return verbalize_digit(number, alt_one=True)


# Basic arithmetic: +, -, ×, ÷, =
RE_ASMD = re.compile(
    r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
asmd_map = {
    '+': '加',
    '-': '减',
    '×': '乘',
    '÷': '除',
    '=': '等于'
}


def replace_asmd(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    result = match.group(1) + asmd_map[match.group(8)] + match.group(9)
    return result


# Number expressions
# Pure decimals
RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')
# Positive integer + quantifier
RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" + COM_QUANTIFIERS)
RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))')


def replace_positive_quantifier(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    number = match.group(1)
    match_2 = match.group(2)
    if match_2 == "+":
        match_2 = "多"
    match_2: str = match_2 if match_2 else ""
    quantifiers: str = match.group(3)
    number: str = num2str(number)
    result = f"{number}{match_2}{quantifiers}"
    return result


def replace_number(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    sign = match.group(1)
    number = match.group(2)
    pure_decimal = match.group(5)
    if pure_decimal:
        result = num2str(pure_decimal)
    else:
        sign: str = "负" if sign else ""
        number: str = num2str(number)
        result = f"{sign}{number}"
    return result


# Range expressions
# match.group(1) and match.group(8) are copied from RE_NUMBER

RE_RANGE = re.compile(
    r"""
    (?<![\d\+\-\×÷=])      # negative lookbehind: no digits or operators before the range
    ((-?)((\d+)(\.\d+)?))  # start of the range: signed integer or decimal
    [-~]                   # range separator
    ((-?)((\d+)(\.\d+)?))  # end of the range: signed integer or decimal
    (?![\d\+\-\×÷=])       # negative lookahead: no digits or operators after the range
    """, re.VERBOSE)


def replace_range(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    first, second = match.group(1), match.group(6)
    first = RE_NUMBER.sub(replace_number, first)
    second = RE_NUMBER.sub(replace_number, second)
    result = f"{first}到{second}"
    return result


# "~" between two measured values is read as 至 ("to")
RE_TO_RANGE = re.compile(
    r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)')


def replace_to_range(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    result = match.group(0).replace('~', '至')
    return result


def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
    stripped = value_string.lstrip('0')
    if len(stripped) == 0:
        return []
    elif len(stripped) == 1:
        if use_zero and len(stripped) < len(value_string):
            return [DIGITS['0'], DIGITS[stripped]]
        else:
            return [DIGITS[stripped]]
    else:
        largest_unit = next(
            power for power in reversed(UNITS.keys()) if power < len(stripped))
        first_part = value_string[:-largest_unit]
        second_part = value_string[-largest_unit:]
        return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(
            second_part)


def verbalize_cardinal(value_string: str) -> str:
    if not value_string:
        return ''

    # 000 -> '零' , 0 -> '零'
    value_string = value_string.lstrip('0')
    if len(value_string) == 0:
        return DIGITS['0']

    result_symbols = _get_value(value_string)
    # verbalized number starting with '一十*' is abbreviated as '十*'
    if len(result_symbols) >= 2 and result_symbols[0] == DIGITS[
            '1'] and result_symbols[1] == UNITS[1]:
        result_symbols = result_symbols[1:]
    return ''.join(result_symbols)


def verbalize_digit(value_string: str, alt_one=False) -> str:
    result_symbols = [DIGITS[digit] for digit in value_string]
    result = ''.join(result_symbols)
    if alt_one:
        result = result.replace("一", "幺")
    return result


def num2str(value_string: str) -> str:
    integer_decimal = value_string.split('.')
    if len(integer_decimal) == 1:
        integer = integer_decimal[0]
        decimal = ''
    elif len(integer_decimal) == 2:
        integer, decimal = integer_decimal
    else:
        raise ValueError(
            f"The value string: '{value_string}' has more than one point in it."
        )

    result = verbalize_cardinal(integer)

    decimal = decimal.rstrip('0')
    if decimal:
        # '.22' is verbalized as '零点二二'
        # '3.20' is verbalized as '三点二'
        result = result if result else "零"
        result += '点' + verbalize_digit(decimal)
    return result
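
A few worked examples of the verbalization helpers (illustrative sketch, same import assumption as above):

from SongBloom.g2p.cn_zh_g2p.zh_normalization.num import num2str, verbalize_digit

print(num2str("10086"))                      # positional reading: 一万零八十六
print(num2str("3.20"))                       # trailing zeros are dropped: 三点二
print(verbalize_digit("2024"))               # digit-by-digit reading: 二零二四
print(verbalize_digit("110", alt_one=True))  # 幺幺零 ("一" read as "幺")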
SongBloom/g2p/cn_zh_g2p/zh_normalization/phonecode.py
ADDED
@@ -0,0 +1,63 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

from .num import verbalize_digit

# Normalize landline / mobile phone numbers
# Mobile prefixes
# http://www.jihaoba.com/news/show/13680
# China Mobile: 139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152, 188, 187, 182, 183, 184, 178, 198
# China Unicom: 130, 131, 132, 156, 155, 186, 185, 176
# China Telecom: 133, 153, 189, 180, 181, 177
RE_MOBILE_PHONE = re.compile(
    r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
RE_TELEPHONE = re.compile(
    r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})(?!\d)")

# Nationwide unified service numbers starting with 400
RE_NATIONAL_UNIFORM_NUMBER = re.compile(r"(400)(-)?\d{3}(-)?\d{4}")


def phone2str(phone_string: str, mobile=True) -> str:
    if mobile:
        sp_parts = phone_string.strip('+').split()
        result = ','.join(
            [verbalize_digit(part, alt_one=True) for part in sp_parts])
        return result
    else:
        sil_parts = phone_string.split('-')
        result = ','.join(
            [verbalize_digit(part, alt_one=True) for part in sil_parts])
        return result


def replace_phone(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    return phone2str(match.group(0), mobile=False)


def replace_mobile(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    return phone2str(match.group(0))
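
Illustrative sketch (same import assumption as above):

from SongBloom.g2p.cn_zh_g2p.zh_normalization.phonecode import (
    RE_MOBILE_PHONE, replace_mobile)

# Mobile numbers are read digit by digit, with "一" replaced by "幺"
print(RE_MOBILE_PHONE.sub(replace_mobile, "请拨13912345678"))
# expected: 请拨幺三九幺二三四五六七八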
SongBloom/g2p/cn_zh_g2p/zh_normalization/quantifier.py
ADDED
@@ -0,0 +1,63 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

from .num import num2str

# Temperature expressions; temperature changes how the minus sign is read
# -3°C -> 零下三度
RE_TEMPERATURE = re.compile(r'(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)')
measure_dict = {
    "cm2": "平方厘米",
    "cm²": "平方厘米",
    "cm3": "立方厘米",
    "cm³": "立方厘米",
    "cm": "厘米",
    "db": "分贝",
    "ds": "毫秒",
    "kg": "千克",
    "km": "千米",
    "m2": "平方米",
    "m²": "平方米",
    "m³": "立方米",
    "m3": "立方米",
    "ml": "毫升",
    "m": "米",
    "mm": "毫米",
    "s": "秒"
}


def replace_temperature(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    sign = match.group(1)
    temperature = match.group(2)
    unit = match.group(3)
    sign: str = "零下" if sign else ""
    temperature: str = num2str(temperature)
    unit: str = "摄氏度" if unit == "摄氏度" else "度"
    result = f"{sign}{temperature}{unit}"
    return result


def replace_measure(sentence) -> str:
    for q_notation in measure_dict:
        if q_notation in sentence:
            sentence = sentence.replace(q_notation, measure_dict[q_notation])
    return sentence
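
Illustrative sketch (same import assumption as above):

from SongBloom.g2p.cn_zh_g2p.zh_normalization.quantifier import (
    RE_TEMPERATURE, replace_measure, replace_temperature)

print(RE_TEMPERATURE.sub(replace_temperature, "今天-3°C"))  # 今天零下三度
print(replace_measure("全长3km"))                           # 全长3千米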
SongBloom/g2p/cn_zh_g2p/zh_normalization/text_normlization.py
ADDED
@@ -0,0 +1,165 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import List

from .char_convert import tranditional_to_simplified
from .chronology import RE_DATE
from .chronology import RE_DATE2
from .chronology import RE_TIME
from .chronology import RE_TIME_RANGE
from .chronology import replace_date
from .chronology import replace_date2
from .chronology import replace_time
from .constants import F2H_ASCII_LETTERS
from .constants import F2H_DIGITS
from .constants import F2H_SPACE
from .num import RE_DECIMAL_NUM
from .num import RE_DEFAULT_NUM
from .num import RE_FRAC
from .num import RE_INTEGER
from .num import RE_NUMBER
from .num import RE_PERCENTAGE
from .num import RE_POSITIVE_QUANTIFIERS
from .num import RE_RANGE
from .num import RE_TO_RANGE
from .num import RE_ASMD
from .num import replace_default_num
from .num import replace_frac
from .num import replace_negative_num
from .num import replace_number
from .num import replace_percentage
from .num import replace_positive_quantifier
from .num import replace_range
from .num import replace_to_range
from .num import replace_asmd
from .phonecode import RE_MOBILE_PHONE
from .phonecode import RE_NATIONAL_UNIFORM_NUMBER
from .phonecode import RE_TELEPHONE
from .phonecode import replace_mobile
from .phonecode import replace_phone
from .quantifier import RE_TEMPERATURE
from .quantifier import replace_measure
from .quantifier import replace_temperature


class TextNormalizer():
    def __init__(self):
        self.SENTENCE_SPLITOR = re.compile(r'([:、,;。?!,;?!][”’]?)')

    def _split(self, text: str, lang="zh") -> List[str]:
        """Split long text into sentences with sentence-splitting punctuations.
        Args:
            text (str): The input text.
        Returns:
            List[str]: Sentences.
        """
        # Only for pure Chinese here
        if lang == "zh":
            text = text.replace(" ", "")
            # Filter out special characters
            text = re.sub(r'[——《》【】<>{}()()#&@“”^_|\\]', '', text)
        text = self.SENTENCE_SPLITOR.sub(r'\1\n', text)
        text = text.strip()
        sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]
        return sentences

    def _post_replace(self, sentence: str) -> str:
        sentence = sentence.replace('/', '每')
        # sentence = sentence.replace('~', '至')
        # sentence = sentence.replace('~', '至')
        sentence = sentence.replace('①', '一')
        sentence = sentence.replace('②', '二')
        sentence = sentence.replace('③', '三')
        sentence = sentence.replace('④', '四')
        sentence = sentence.replace('⑤', '五')
        sentence = sentence.replace('⑥', '六')
        sentence = sentence.replace('⑦', '七')
        sentence = sentence.replace('⑧', '八')
        sentence = sentence.replace('⑨', '九')
        sentence = sentence.replace('⑩', '十')
        sentence = sentence.replace('α', '阿尔法')
        sentence = sentence.replace('β', '贝塔')
        sentence = sentence.replace('γ', '伽玛').replace('Γ', '伽玛')
        sentence = sentence.replace('δ', '德尔塔').replace('Δ', '德尔塔')
        sentence = sentence.replace('ε', '艾普西龙')
        sentence = sentence.replace('ζ', '捷塔')
        sentence = sentence.replace('η', '依塔')
        sentence = sentence.replace('θ', '西塔').replace('Θ', '西塔')
        sentence = sentence.replace('ι', '艾欧塔')
        sentence = sentence.replace('κ', '喀帕')
        sentence = sentence.replace('λ', '拉姆达').replace('Λ', '拉姆达')
        sentence = sentence.replace('μ', '缪')
        sentence = sentence.replace('ν', '拗')
        sentence = sentence.replace('ξ', '克西').replace('Ξ', '克西')
        sentence = sentence.replace('ο', '欧米克伦')
        sentence = sentence.replace('π', '派').replace('Π', '派')
        sentence = sentence.replace('ρ', '肉')
        sentence = sentence.replace('ς', '西格玛').replace('Σ', '西格玛').replace(
            'σ', '西格玛')
        sentence = sentence.replace('τ', '套')
        sentence = sentence.replace('υ', '宇普西龙')
        sentence = sentence.replace('φ', '服艾').replace('Φ', '服艾')
        sentence = sentence.replace('χ', '器')
        sentence = sentence.replace('ψ', '普赛').replace('Ψ', '普赛')
        sentence = sentence.replace('ω', '欧米伽').replace('Ω', '欧米伽')
        # Filter special characters again; compared with the filter in
        # _split, this pattern also removes "-" and "=".
        sentence = re.sub(r'[-——《》【】<=>{}()()#&@“”^_|\\]', '', sentence)
        return sentence

    def normalize_sentence(self, sentence: str) -> str:
        # basic character conversions
        sentence = tranditional_to_simplified(sentence)
        sentence = sentence.translate(F2H_ASCII_LETTERS).translate(
            F2H_DIGITS).translate(F2H_SPACE)

        # number related NSW verbalization
        sentence = RE_DATE.sub(replace_date, sentence)
        sentence = RE_DATE2.sub(replace_date2, sentence)

        # range first
        sentence = RE_TIME_RANGE.sub(replace_time, sentence)
        sentence = RE_TIME.sub(replace_time, sentence)

        # Handle "~" used as 至 ("to") between measured values
        sentence = RE_TO_RANGE.sub(replace_to_range, sentence)
        sentence = RE_TEMPERATURE.sub(replace_temperature, sentence)
        sentence = replace_measure(sentence)
        sentence = RE_FRAC.sub(replace_frac, sentence)
        sentence = RE_PERCENTAGE.sub(replace_percentage, sentence)
        sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence)

        sentence = RE_TELEPHONE.sub(replace_phone, sentence)
        sentence = RE_NATIONAL_UNIFORM_NUMBER.sub(replace_phone, sentence)

        sentence = RE_RANGE.sub(replace_range, sentence)

        # Handle +, -, ×, ÷, =
        while RE_ASMD.search(sentence):
            sentence = RE_ASMD.sub(replace_asmd, sentence)

        sentence = RE_INTEGER.sub(replace_negative_num, sentence)
        sentence = RE_DECIMAL_NUM.sub(replace_number, sentence)
        sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier,
                                               sentence)
        sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence)
        sentence = RE_NUMBER.sub(replace_number, sentence)
        sentence = self._post_replace(sentence)

        return sentence

    def normalize(self, text: str) -> List[str]:
        sentences = self._split(text)
        sentences = [self.normalize_sentence(sent) for sent in sentences]
        return sentences
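
End-to-end sketch of the normalizer pipeline (illustrative; same import assumption as above):

from SongBloom.g2p.cn_zh_g2p.zh_normalization.text_normlization import TextNormalizer

tn = TextNormalizer()
print(tn.normalize("今天是2024/01/05,气温-3°C,客服电话4001234567。"))
# expected (split on sentence punctuation, then verbalized):
# ['今天是二零二四年一月五日,', '气温零下三度,', '客服电话四零零幺二三四五六七。']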
SongBloom/g2p/lyric_common.py
ADDED
@@ -0,0 +1,81 @@
import os, sys

sys.path.insert(0, os.path.dirname(__file__))
from pinyin.pinyin import G2P_PinYin
from cn_zh_g2p import G2P_Mix, symbols

key2processor = {
    'pinyin': G2P_PinYin(),
    'phoneme': G2P_Mix(),
}

valid_struct_type = ['[chorus]', '[verse]', '[bridge]']
start_struct_type = ['[intro]', '[start]']
end_struct_type = ['[outro]', '[end]']
conn_struct_type = ['[inst]', '[solo]', '[break]']

LABELS = {
    '[intro]': 0,
    '[outro]': 1,
    '[bridge]': 2,
    '[inst]': 3,
    '[verse]': 4,
    '[chorus]': 5,
    '[silence]': 6,
}

NUMBERS = {
    '0': ['零', 'zero'],
    '1': ['一', 'one'],
    '2': ['二', 'two'],
    '3': ['三', 'three'],
    '4': ['四', 'four'],
    '5': ['五', 'five'],
    '6': ['六', 'six'],
    '7': ['七', 'seven'],
    '8': ['八', 'eight'],
    '9': ['九', 'nine']
}


def detect_structure(structure):
    valid_start = ['start', 'intro']
    valid_end = ['outro', 'end']
    valid_instru = ['solo', 'inst', 'break']
    valid_bridge = ['bridge']

    if structure in ['verse', 'chorus', 'silence']:
        return structure

    if structure in valid_start:
        return 'intro'
    if structure in valid_end:
        return 'outro'
    if structure in valid_instru:
        return 'inst'
    if structure in valid_bridge:
        return 'bridge'


def merge_structure(start_time, end_time, structure, lyric):
    cnt = 1
    while cnt < len(start_time):
        if structure[cnt] == structure[cnt-1]:
            end_time[cnt-1] = end_time[cnt]
            if structure[cnt] not in ["verse", "chorus", "bridge"]:
                del start_time[cnt]
                del end_time[cnt]
                del structure[cnt]
                del lyric[cnt]
            else:
                cnt += 1
        else:
            cnt += 1

    return start_time, end_time, structure, lyric


def is_struct_legal(struct, text):
    if struct in valid_struct_type and text != "":
        return True
    elif struct not in valid_struct_type and text == "":
        return True
    return False
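
Illustrative sketch of structure merging (note: importing this module eagerly builds both g2p processors via key2processor, so their model resources must be available):

from SongBloom.g2p.lyric_common import merge_structure

start, end, struct, lyric = merge_structure(
    [0, 10, 20], [10, 20, 30],
    ['intro', 'intro', 'verse'], ['', '', 'la la la'])
print(struct)  # ['intro', 'verse'] - adjacent non-vocal segments were merged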
SongBloom/g2p/pinyin/__init__.py
ADDED
@@ -0,0 +1,430 @@
from .symbols import symbols


# Mapping from toneless pinyin syllables to (initial, final) pairs.
# "^" marks a zero (empty) initial.
pinyin_dict = {
    "a": ("^", "a"), "ai": ("^", "ai"), "an": ("^", "an"), "ang": ("^", "ang"), "ao": ("^", "ao"),
    "ba": ("b", "a"), "bai": ("b", "ai"), "ban": ("b", "an"), "bang": ("b", "ang"), "bao": ("b", "ao"),
    "be": ("b", "e"), "bei": ("b", "ei"), "ben": ("b", "en"), "beng": ("b", "eng"), "bi": ("b", "i"),
    "bian": ("b", "ian"), "biao": ("b", "iao"), "bie": ("b", "ie"), "bin": ("b", "in"), "bing": ("b", "ing"),
    "bo": ("b", "o"), "bu": ("b", "u"),
    "ca": ("c", "a"), "cai": ("c", "ai"), "can": ("c", "an"), "cang": ("c", "ang"), "cao": ("c", "ao"),
    "ce": ("c", "e"), "cen": ("c", "en"), "ceng": ("c", "eng"),
    "cha": ("ch", "a"), "chai": ("ch", "ai"), "chan": ("ch", "an"), "chang": ("ch", "ang"), "chao": ("ch", "ao"),
    "che": ("ch", "e"), "chen": ("ch", "en"), "cheng": ("ch", "eng"), "chi": ("ch", "iii"), "chong": ("ch", "ong"),
    "chou": ("ch", "ou"), "chu": ("ch", "u"), "chua": ("ch", "ua"), "chuai": ("ch", "uai"), "chuan": ("ch", "uan"),
    "chuang": ("ch", "uang"), "chui": ("ch", "uei"), "chun": ("ch", "uen"), "chuo": ("ch", "uo"),
    "ci": ("c", "ii"), "cong": ("c", "ong"), "cou": ("c", "ou"), "cu": ("c", "u"), "cuan": ("c", "uan"),
    "cui": ("c", "uei"), "cun": ("c", "uen"), "cuo": ("c", "uo"),
    "da": ("d", "a"), "dai": ("d", "ai"), "dan": ("d", "an"), "dang": ("d", "ang"), "dao": ("d", "ao"),
    "de": ("d", "e"), "dei": ("d", "ei"), "den": ("d", "en"), "deng": ("d", "eng"), "di": ("d", "i"),
    "dia": ("d", "ia"), "dian": ("d", "ian"), "diao": ("d", "iao"), "die": ("d", "ie"), "ding": ("d", "ing"),
    "diu": ("d", "iou"), "dong": ("d", "ong"), "dou": ("d", "ou"), "du": ("d", "u"), "duan": ("d", "uan"),
    "dui": ("d", "uei"), "dun": ("d", "uen"), "duo": ("d", "uo"),
    "e": ("^", "e"), "ei": ("^", "ei"), "en": ("^", "en"), "ng": ("^", "en"), "eng": ("^", "eng"), "er": ("^", "er"),
    "fa": ("f", "a"), "fan": ("f", "an"), "fang": ("f", "ang"), "fei": ("f", "ei"), "fen": ("f", "en"),
    "feng": ("f", "eng"), "fo": ("f", "o"), "fou": ("f", "ou"), "fu": ("f", "u"),
    "ga": ("g", "a"), "gai": ("g", "ai"), "gan": ("g", "an"), "gang": ("g", "ang"), "gao": ("g", "ao"),
    "ge": ("g", "e"), "gei": ("g", "ei"), "gen": ("g", "en"), "geng": ("g", "eng"), "gong": ("g", "ong"),
    "gou": ("g", "ou"), "gu": ("g", "u"), "gua": ("g", "ua"), "guai": ("g", "uai"), "guan": ("g", "uan"),
    "guang": ("g", "uang"), "gui": ("g", "uei"), "gun": ("g", "uen"), "guo": ("g", "uo"),
    "ha": ("h", "a"), "hai": ("h", "ai"), "han": ("h", "an"), "hang": ("h", "ang"), "hao": ("h", "ao"),
    "he": ("h", "e"), "hei": ("h", "ei"), "hen": ("h", "en"), "heng": ("h", "eng"), "hong": ("h", "ong"),
    "hou": ("h", "ou"), "hu": ("h", "u"), "hua": ("h", "ua"), "huai": ("h", "uai"), "huan": ("h", "uan"),
    "huang": ("h", "uang"), "hui": ("h", "uei"), "hun": ("h", "uen"), "huo": ("h", "uo"),
    "ji": ("j", "i"), "jia": ("j", "ia"), "jian": ("j", "ian"), "jiang": ("j", "iang"), "jiao": ("j", "iao"),
    "jie": ("j", "ie"), "jin": ("j", "in"), "jing": ("j", "ing"), "jiong": ("j", "iong"), "jiu": ("j", "iou"),
    "ju": ("j", "v"), "juan": ("j", "van"), "jue": ("j", "ve"), "jun": ("j", "vn"),
    "ka": ("k", "a"), "kai": ("k", "ai"), "kan": ("k", "an"), "kang": ("k", "ang"), "kao": ("k", "ao"),
    "ke": ("k", "e"), "kei": ("k", "ei"), "ken": ("k", "en"), "keng": ("k", "eng"), "kong": ("k", "ong"),
    "kou": ("k", "ou"), "ku": ("k", "u"), "kua": ("k", "ua"), "kuai": ("k", "uai"), "kuan": ("k", "uan"),
    "kuang": ("k", "uang"), "kui": ("k", "uei"), "kun": ("k", "uen"), "kuo": ("k", "uo"),
    "la": ("l", "a"), "lai": ("l", "ai"), "lan": ("l", "an"), "lang": ("l", "ang"), "lao": ("l", "ao"),
    "le": ("l", "e"), "lei": ("l", "ei"), "leng": ("l", "eng"), "li": ("l", "i"), "lia": ("l", "ia"),
    "lian": ("l", "ian"), "liang": ("l", "iang"), "liao": ("l", "iao"), "lie": ("l", "ie"), "lin": ("l", "in"),
    "ling": ("l", "ing"), "liu": ("l", "iou"), "lo": ("l", "o"), "long": ("l", "ong"), "lou": ("l", "ou"),
    "lu": ("l", "u"), "lv": ("l", "v"), "luan": ("l", "uan"), "lve": ("l", "ve"), "lue": ("l", "ve"),
    "lun": ("l", "uen"), "luo": ("l", "uo"),
    "ma": ("m", "a"), "mai": ("m", "ai"), "man": ("m", "an"), "mang": ("m", "ang"), "mao": ("m", "ao"),
    "me": ("m", "e"), "mei": ("m", "ei"), "men": ("m", "en"), "meng": ("m", "eng"), "mi": ("m", "i"),
    "mian": ("m", "ian"), "miao": ("m", "iao"), "mie": ("m", "ie"), "min": ("m", "in"), "ming": ("m", "ing"),
    "miu": ("m", "iou"), "mo": ("m", "o"), "mou": ("m", "ou"), "mu": ("m", "u"),
    "na": ("n", "a"), "nai": ("n", "ai"), "nan": ("n", "an"), "nang": ("n", "ang"), "nao": ("n", "ao"),
    "ne": ("n", "e"), "nei": ("n", "ei"), "nen": ("n", "en"), "neng": ("n", "eng"), "ni": ("n", "i"),
    "nia": ("n", "ia"), "nian": ("n", "ian"), "niang": ("n", "iang"), "niao": ("n", "iao"), "nie": ("n", "ie"),
    "nin": ("n", "in"), "ning": ("n", "ing"), "niu": ("n", "iou"), "nong": ("n", "ong"), "nou": ("n", "ou"),
    "nu": ("n", "u"), "nv": ("n", "v"), "nuan": ("n", "uan"), "nve": ("n", "ve"), "nue": ("n", "ve"),
    "nuo": ("n", "uo"),
    "o": ("^", "o"), "ou": ("^", "ou"),
    "pa": ("p", "a"), "pai": ("p", "ai"), "pan": ("p", "an"), "pang": ("p", "ang"), "pao": ("p", "ao"),
    "pe": ("p", "e"), "pei": ("p", "ei"), "pen": ("p", "en"), "peng": ("p", "eng"), "pi": ("p", "i"),
    "pian": ("p", "ian"), "piao": ("p", "iao"), "pie": ("p", "ie"), "pin": ("p", "in"), "ping": ("p", "ing"),
    "po": ("p", "o"), "pou": ("p", "ou"), "pu": ("p", "u"),
    "qi": ("q", "i"), "qia": ("q", "ia"), "qian": ("q", "ian"), "qiang": ("q", "iang"), "qiao": ("q", "iao"),
    "qie": ("q", "ie"), "qin": ("q", "in"), "qing": ("q", "ing"), "qiong": ("q", "iong"), "qiu": ("q", "iou"),
    "qu": ("q", "v"), "quan": ("q", "van"), "que": ("q", "ve"), "qun": ("q", "vn"),
    "ran": ("r", "an"), "rang": ("r", "ang"), "rao": ("r", "ao"), "re": ("r", "e"), "ren": ("r", "en"),
    "reng": ("r", "eng"), "ri": ("r", "iii"), "rong": ("r", "ong"), "rou": ("r", "ou"), "ru": ("r", "u"),
    "rua": ("r", "ua"), "ruan": ("r", "uan"), "rui": ("r", "uei"), "run": ("r", "uen"), "ruo": ("r", "uo"),
    "sa": ("s", "a"), "sai": ("s", "ai"), "san": ("s", "an"), "sang": ("s", "ang"), "sao": ("s", "ao"),
    "se": ("s", "e"), "sen": ("s", "en"), "seng": ("s", "eng"),
    "sha": ("sh", "a"), "shai": ("sh", "ai"), "shan": ("sh", "an"), "shang": ("sh", "ang"), "shao": ("sh", "ao"),
    "she": ("sh", "e"), "shei": ("sh", "ei"), "shen": ("sh", "en"), "sheng": ("sh", "eng"), "shi": ("sh", "iii"),
    "shou": ("sh", "ou"), "shu": ("sh", "u"), "shua": ("sh", "ua"), "shuai": ("sh", "uai"), "shuan": ("sh", "uan"),
    "shuang": ("sh", "uang"), "shui": ("sh", "uei"), "shun": ("sh", "uen"), "shuo": ("sh", "uo"),
    "si": ("s", "ii"), "song": ("s", "ong"), "sou": ("s", "ou"), "su": ("s", "u"), "suan": ("s", "uan"),
    "sui": ("s", "uei"), "sun": ("s", "uen"), "suo": ("s", "uo"),
    "ta": ("t", "a"), "tai": ("t", "ai"), "tan": ("t", "an"), "tang": ("t", "ang"), "tao": ("t", "ao"),
    "te": ("t", "e"), "tei": ("t", "ei"), "teng": ("t", "eng"), "ti": ("t", "i"), "tian": ("t", "ian"),
    "tiao": ("t", "iao"), "tie": ("t", "ie"), "ting": ("t", "ing"), "tong": ("t", "ong"), "tou": ("t", "ou"),
    "tu": ("t", "u"), "tuan": ("t", "uan"), "tui": ("t", "uei"), "tun": ("t", "uen"), "tuo": ("t", "uo"),
    "wa": ("^", "ua"), "wai": ("^", "uai"), "wan": ("^", "uan"), "wang": ("^", "uang"), "wei": ("^", "uei"),
    "wen": ("^", "uen"), "weng": ("^", "ueng"), "wo": ("^", "uo"), "wu": ("^", "u"),
    "xi": ("x", "i"), "xia": ("x", "ia"), "xian": ("x", "ian"), "xiang": ("x", "iang"), "xiao": ("x", "iao"),
    "xie": ("x", "ie"), "xin": ("x", "in"), "xing": ("x", "ing"), "xiong": ("x", "iong"), "xiu": ("x", "iou"),
    "xu": ("x", "v"), "xuan": ("x", "van"), "xue": ("x", "ve"), "xun": ("x", "vn"),
    "ya": ("^", "ia"), "yan": ("^", "ian"), "yang": ("^", "iang"), "yao": ("^", "iao"), "ye": ("^", "ie"),
    "yi": ("^", "i"), "yin": ("^", "in"), "ying": ("^", "ing"), "yo": ("^", "iou"), "yong": ("^", "iong"),
    "you": ("^", "iou"), "yu": ("^", "v"), "yuan": ("^", "van"), "yue": ("^", "ve"), "yun": ("^", "vn"),
    "za": ("z", "a"), "zai": ("z", "ai"), "zan": ("z", "an"), "zang": ("z", "ang"), "zao": ("z", "ao"),
    "ze": ("z", "e"), "zei": ("z", "ei"), "zen": ("z", "en"), "zeng": ("z", "eng"),
    "zha": ("zh", "a"), "zhai": ("zh", "ai"), "zhan": ("zh", "an"), "zhang": ("zh", "ang"), "zhao": ("zh", "ao"),
    "zhe": ("zh", "e"), "zhei": ("zh", "ei"), "zhen": ("zh", "en"), "zheng": ("zh", "eng"), "zhi": ("zh", "iii"),
    "zhong": ("zh", "ong"), "zhou": ("zh", "ou"), "zhu": ("zh", "u"), "zhua": ("zh", "ua"), "zhuai": ("zh", "uai"),
    "zhuan": ("zh", "uan"), "zhuang": ("zh", "uang"), "zhui": ("zh", "uei"), "zhun": ("zh", "uen"), "zhuo": ("zh", "uo"),
    "zi": ("z", "ii"), "zong": ("z", "ong"), "zou": ("z", "ou"), "zu": ("z", "u"), "zuan": ("z", "uan"),
    "zui": ("z", "uei"), "zun": ("z", "uen"), "zuo": ("z", "uo"),
}


def gen_vocabs():
    import yaml
    vocab = [f"<{c}{i}>" for c in list(pinyin_dict.keys()) for i in range(1, 6)]
    yaml.dump(vocab, open('./vocab.yaml', 'w'))
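
Illustrative sketch of the mapping (same import assumption as above):

from SongBloom.g2p.pinyin import pinyin_dict

print(pinyin_dict["zhuang"])  # ('zh', 'uang')
print(pinyin_dict["yu"])      # ('^', 'v'): a zero initial is spelled '^'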
SongBloom/g2p/pinyin/pinyin.py
ADDED
@@ -0,0 +1,137 @@
import re

from pypinyin import Style
from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin
from pypinyin.converter import DefaultConverter
from pypinyin.core import Pinyin

from . import pinyin_dict
import torch


class MyConverter(NeutralToneWith5Mixin, DefaultConverter):
    pass


def is_chinese(uchar):
    if uchar >= u'\u4e00' and uchar <= u'\u9fa5':
        return True
    else:
        return False


def clean_chinese(text: str):
    text = text.strip()
    text_clean = []
    for char in text:
        if (is_chinese(char)):
            text_clean.append(char)
        else:
            if len(text_clean) > 1 and is_chinese(text_clean[-1]):
                text_clean.append(',')
    text_clean = ''.join(text_clean).strip(',')
    return text_clean


class G2P_PinYin():

    def __init__(self):
        super(G2P_PinYin, self).__init__()
        self.pinyin_parser = Pinyin(MyConverter())

    def get_phoneme4pinyin(self, pinyins):
        result = []
        count_phone = []
        for pinyin in pinyins:
            if pinyin[:-1] in pinyin_dict:
                tone = pinyin[-1]
                a = pinyin[:-1]
                a1, a2 = pinyin_dict[a]
                result += [a1, a2 + tone]
                count_phone.append(2)
        return result, count_phone

    # def chinese_to_phonemes(self, text):
    #     text = clean_chinese(text)
    #     phonemes = ["sil"]
    #     chars = ['[PAD]']
    #     all_pinyins = []
    #     count_phone = []
    #     count_phone.append(1)
    #     for subtext in text.split(","):
    #         if (len(subtext) == 0):
    #             continue
    #         pinyins = self.correct_pinyin_tone3(subtext)
    #         all_pinyins.append(' '.join(pinyins))
    #         sub_p, sub_c = self.get_phoneme4pinyin(pinyins)
    #         phonemes.extend(sub_p)
    #         phonemes.append(",")
    #         count_phone.extend(sub_c)
    #         count_phone.append(1)
    #         chars.append(subtext)
    #         chars.append(',')
    #     phonemes.append("sil")
    #     count_phone.append(1)
    #     chars.append('[PAD]')
    #     # char_embeds = self.prosody.expand_for_phone(char_embeds, count_phone)
    #     return " ".join(phonemes), " ".join(chars), ' , '.join(all_pinyins)

    def chinese_to_phonemes(self, text):
        # Convert each maximal run of Chinese characters to "<pinyin>" tokens,
        # passing non-Chinese characters through unchanged.
        all_pinyins = []
        subtext = []
        for chr in text:
            if is_chinese(chr):
                subtext.append(chr)
            else:
                if subtext != []:
                    subtext = ''.join(subtext)
                    pinyins = self.correct_pinyin_tone3(subtext)
                    pinyins = [f"<{i}>" for i in pinyins]
                    all_pinyins.append(' ' + ' '.join(pinyins) + ' ')
                all_pinyins.append(chr)
                subtext = []
        if subtext != []:
            subtext = ''.join(subtext)
            pinyins = self.correct_pinyin_tone3(subtext)
            pinyins = [f"<{i}>" for i in pinyins]
            all_pinyins.append(' ' + ' '.join(pinyins) + ' ')
        # char_embeds = self.prosody.expand_for_phone(char_embeds, count_phone)
        return ''.join(all_pinyins)

    def correct_pinyin_tone3(self, text):
        pinyin_list = [
            p[0]
            for p in self.pinyin_parser.pinyin(text,
                                               style=Style.TONE3,
                                               strict=False,
                                               neutral_tone_with_five=True)
        ]
        # Third-tone sandhi: when two third tones are adjacent,
        # the first one is read as a second tone.
        if len(pinyin_list) >= 2:
            for i in range(1, len(pinyin_list)):
                try:
                    if re.findall(r'\d',
                                  pinyin_list[i - 1])[0] == '3' and re.findall(
                                      r'\d', pinyin_list[i])[0] == '3':
                        pinyin_list[i - 1] = pinyin_list[i - 1].replace(
                            '3', '2')
                except IndexError:
                    pass
        return pinyin_list

    # def expand_for_phone(self, char_embeds, length):  # length of phones for char
    #     if(char_embeds.size(0) > len(length)):
    #         print(char_embeds.shape, len(length))
    #         char_embeds = char_embeds[0:len(length),:]
    #     elif(char_embeds.size(0) < len(length)):
    #         print(char_embeds.shape, len(length))
    #         length = length[0:char_embeds.size(0)]
    #     expand_vecs = list()
    #     for vec, leng in zip(char_embeds, length):
    #         vec = vec.expand(leng, -1)
    #         expand_vecs.append(vec)
    #     expand_embeds = torch.cat(expand_vecs, 0)
    #     assert expand_embeds.size(0) == sum(length)
    #     return expand_embeds.numpy()

    def __call__(self, text):
        return self.chinese_to_phonemes(text)
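
Illustrative sketch (same import assumption as above; the expected output reflects the tone-sandhi rule in correct_pinyin_tone3):

from SongBloom.g2p.pinyin.pinyin import G2P_PinYin

g2p = G2P_PinYin()
print(g2p("你好世界"))
# expected: ' <ni2> <hao3> <shi4> <jie4> '
# (two adjacent third tones trigger sandhi, so ni3 becomes ni2)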
SongBloom/g2p/pinyin/symbols.py
ADDED
@@ -0,0 +1,71 @@
_pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"]

_initials = [
    "^", "b", "c", "ch", "d", "f", "g", "h", "j", "k", "l",
    "m", "n", "p", "q", "r", "s", "sh", "t", "x", "z", "zh",
]

_tones = ["1", "2", "3", "4", "5"]

_finals = [
    "a", "ai", "an", "ang", "ao", "e", "ei", "en", "eng", "er",
    "i", "ia", "ian", "iang", "iao", "ie", "ii", "iii", "in", "ing",
    "iong", "iou", "o", "ong", "ou", "u", "ua", "uai", "uan", "uang",
    "uei", "uen", "ueng", "uo", "v", "van", "ve", "vn",
]

symbols = _pause + _initials + [i + j for i in _finals for j in _tones]
SongBloom/models/base/sample.py
ADDED
@@ -0,0 +1,57 @@
+import torch
+
+
+def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
+    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
+
+    Args:
+        input (torch.Tensor): The input tensor containing probabilities.
+        num_samples (int): Number of samples to draw.
+        replacement (bool): Whether to draw with replacement or not.
+    Keyword args:
+        generator (torch.Generator): A pseudorandom number generator for sampling.
+    Returns:
+        torch.Tensor: Last dimension contains num_samples indices
+            sampled from the multinomial probability distribution
+            located in the last dimension of tensor input.
+    """
+    input_ = input.reshape(-1, input.shape[-1])
+    output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
+    output = output_.reshape(*list(input.shape[:-1]), -1)
+    return output
+
+
+def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
+    """Sample next token from top K values along the last dimension of the input probs tensor.
+
+    Args:
+        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+        k (int): The k in "top-k".
+    Returns:
+        torch.Tensor: Sampled tokens.
+    """
+    top_k_value, _ = torch.topk(probs, k, dim=-1)
+    min_value_top_k = top_k_value[..., [-1]]
+    probs *= (probs >= min_value_top_k).float()
+    probs.div_(probs.sum(dim=-1, keepdim=True))
+    next_token = multinomial(probs, num_samples=1)
+    return next_token
+
+
+def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
+    """Sample next token from top P probabilities along the last dimension of the input probs tensor.
+
+    Args:
+        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+        p (float): The p in "top-p".
+    Returns:
+        torch.Tensor: Sampled tokens.
+    """
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort *= (~mask).float()
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+    return next_token
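
As a usage sketch: both samplers expect normalized probabilities (e.g. a softmax output) with token candidates on the last axis and return index tensors with a trailing dimension of 1. Note that `sample_top_k` modifies `probs` in place (`sample_top_p` works on a sorted copy), so pass a clone if the distribution is reused. The import path follows this repo's layout:

import torch
from SongBloom.models.base.sample import sample_top_k, sample_top_p

probs = torch.softmax(torch.randn(2, 4, 100), dim=-1)  # [B, T, vocab]
tok_k = sample_top_k(probs.clone(), k=50)              # clone: sample_top_k mutates its input
tok_p = sample_top_p(probs.clone(), p=0.9)
print(tok_k.shape, tok_p.shape)  # torch.Size([2, 4, 1]) torch.Size([2, 4, 1])
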
SongBloom/models/base/utils.py
ADDED
@@ -0,0 +1,57 @@
+import torch
+import torch.nn as nn
+import typing as tp
+
+def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
+    """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
+    For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
+
+    Args:
+        lengths (torch.Tensor): tensor with lengths
+        max_len (int): can set the max length manually. Defaults to None.
+    Returns:
+        torch.Tensor: mask with 0s where there are pad tokens, else 1s
+    """
+    assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
+    final_length = lengths.max().item() if not max_len else max_len
+    final_length = max(final_length, 1)  # if all seqs are of len zero we don't want a zero-size tensor
+    return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None]
+
+
+def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
+                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    """Create sinusoidal positional embedding, with shape `[B, T, C]`.
+
+    Args:
+        positions (torch.Tensor): LongTensor of positions.
+        dim (int): Dimension of the embedding.
+        max_period (float): Maximum period of the cosine/sine functions.
+        dtype (torch.dtype or str): dtype to use to generate the embedding.
+    Returns:
+        torch.Tensor: Sinusoidal positional embedding.
+    """
+    # We aim for BTC format
+    assert dim % 2 == 0
+    half_dim = dim // 2
+    positions = positions.to(dtype)
+    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
+    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)  # avoid sync point
+    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
+    # phase = phase.to(torch.bfloat16)
+    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
+
+
+def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
+    """Create normalization module for transformer encoder layer.
+
+    Args:
+        norm_type (str): Normalization method.
+        dim (int): Dimension of the normalized layer.
+        **kwargs (dict): Additional parameters for normalization layer.
+    Returns:
+        nn.Module: Normalization module.
+    """
+    if norm_type == 'layer_norm':
+        return nn.LayerNorm(dim, eps=1e-5, **kwargs)
+    else:
+        raise ValueError(f"Unknown norm type: {norm_type}")
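
A quick shape check of the two helpers above, matching the docstring example (import path per this repo's layout):

import torch
from SongBloom.models.base.utils import length_to_mask, create_sin_embedding

mask = length_to_mask(torch.tensor([3, 5]))
print(mask.int())  # tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])

positions = torch.arange(10).view(1, -1, 1)           # [B=1, T=10, 1]
print(create_sin_embedding(positions, dim=64).shape)  # torch.Size([1, 10, 64])
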
SongBloom/models/musicgen/__init__.py
ADDED
File without changes
SongBloom/models/musicgen/conditioners/__init__.py
ADDED
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import omegaconf
+from .base import *
+from .text import *
+from .wav import *
+
+KLASS = {
+    'phoneme_tokenizer': PhonemeTokenizerConditioner,
+    'audio_tokenizer_wrapper': AudioTokenizerConditioner,
+}
+
+def get_condition_fuser(fuser_cfgs) -> ConditionFuser:
+    """Instantiate a condition fuser object."""
+    fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate']
+    fuse2cond = {k: fuser_cfgs[k] for k in fuser_methods}
+    kwargs = {k: v for k, v in fuser_cfgs.items() if k not in fuser_methods}
+    fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
+    return fuser
+
+def get_conditioner_provider(cfg) -> ConditioningProvider:
+    """Instantiate a conditioning model."""
+
+    dict_cfg = {} if cfg is None else dict(cfg)
+    conditioners: tp.Dict[str, BaseConditioner] = {}
+
+    # import pdb; pdb.set_trace()
+    for cond, cond_cfg in dict_cfg.items():
+        model_args = cond_cfg.copy()
+        model_type = model_args.pop('type')
+        conditioners[str(cond)] = KLASS[model_type](**model_args)
+    conditioner = ConditioningProvider(conditioners)
+    return conditioner
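
To make the expected config shapes concrete: `get_conditioner_provider` wants a mapping from condition name to a kwargs dict whose `type` key selects a `KLASS` entry, and `get_condition_fuser` wants the four fusing-method keys plus any extra `ConditionFuser` kwargs. A sketch with hypothetical values (the real ones live in the repo's yaml configs):

# Hypothetical values for illustration only.
cond_cfg = {
    'lyric': {'type': 'phoneme_tokenizer',
              'output_dim': 1024,
              'vocab_list': ['', '.', ',', '[verse]', 'ni3', 'hao3']},
}
provider = get_conditioner_provider(cond_cfg)

fuser_cfgs = {'sum': [], 'cross': [], 'prepend': ['lyric'], 'input_interpolate': [],
              'cross_attention_pos_emb': False}
fuser = get_condition_fuser(fuser_cfgs)
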
SongBloom/models/musicgen/conditioners/base.py
ADDED
@@ -0,0 +1,872 @@
+from collections import defaultdict
+from copy import deepcopy
+from dataclasses import dataclass, field
+from itertools import chain
+import logging
+import typing as tp
+import einops
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
+
+from dataclasses import dataclass, field, fields, replace
+
+from ..modules.streaming import StreamingModule
+from ...base.utils import length_to_mask, create_sin_embedding
+
+
+def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+    """Get a list of tensors and collate them to a single tensor according to the following logic:
+    - `dim` specifies the time dimension which will be stacked and padded.
+    - The output will contain 1 new dimension (dimension index 0) which will be the size
+      of the original list.
+
+    Args:
+        tensors (tp.List[torch.Tensor]): List of tensors to collate.
+        dim (int): Dimension which will be stacked and padded.
+    Returns:
+        tp.Tuple[torch.Tensor, torch.Tensor]:
+            torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
+                (dimension index 0) which will be the size of the original list.
+            torch.Tensor: Tensor containing length of original tensor sizes (without padding).
+    """
+    tensors = [x.transpose(0, dim) for x in tensors]
+    lens = torch.LongTensor([len(x) for x in tensors])
+    padded_tensors = pad_sequence(tensors)
+    padded_tensors = padded_tensors.transpose(0, 1)
+    padded_tensors = padded_tensors.transpose(1, dim + 1)
+    return padded_tensors, lens
+
+
+
+@dataclass(order=True)
+class PathInZip:
+    """Hold a path of file within a zip file.
+
+    Args:
+        path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
+            Let's assume there is a zip file /some/location/foo.zip
+            and inside of it is a json file located at /data/file1.json,
+            Then we expect path = "/some/location/foo.zip:/data/file1.json".
+    """
+
+    INFO_PATH_SEP = ':'
+    zip_path: str
+    file_path: str
+
+    def __init__(self, path: str) -> None:
+        split_path = path.split(self.INFO_PATH_SEP)
+        assert len(split_path) == 2
+        self.zip_path, self.file_path = split_path
+
+    @classmethod
+    def from_paths(cls, zip_path: str, file_path: str):
+        return cls(zip_path + cls.INFO_PATH_SEP + file_path)
+
+    def __str__(self) -> str:
+        return self.zip_path + self.INFO_PATH_SEP + self.file_path
+
+
+@dataclass(order=True)
+class BaseInfo:
+
+    @classmethod
+    def _dict2fields(cls, dictionary: dict):
+        return {
+            field.name: dictionary[field.name]
+            for field in fields(cls) if field.name in dictionary
+        }
+    # try:
+    #     return {
+    #         field.name: dictionary[field.name]
+    #         for field in fields(cls) if field.name in dictionary
+    #     }
+    # except:
+    #     print(dictionary)
+
+    @classmethod
+    def from_dict(cls, dictionary: dict):
+        _dictionary = cls._dict2fields(dictionary)
+        return cls(**_dictionary)
+
+    def to_dict(self):
+        return {
+            field.name: self.__getattribute__(field.name)
+            for field in fields(self)
+        }
+
+
+@dataclass(order=True)
+class AudioMeta(BaseInfo):
+    path: str
+    duration: float
+    sample_rate: int
+    amplitude: tp.Optional[float] = None
+    weight: tp.Optional[float] = None
+    # info_path is used to load additional information about the audio file that is stored in zip files.
+    info_path: tp.Optional[PathInZip] = None
+
+    @classmethod
+    def from_dict(cls, dictionary: dict):
+        base = cls._dict2fields(dictionary)
+        if 'info_path' in base and base['info_path'] is not None:
+            base['info_path'] = PathInZip(base['info_path'])
+        return cls(**base)
+
+    def to_dict(self):
+        d = super().to_dict()
+        if d['info_path'] is not None:
+            d['info_path'] = str(d['info_path'])
+        return d
+
+
+@dataclass(order=True)
+class SegmentInfo(BaseInfo):
+    meta: AudioMeta
+    seek_time: float
+    # The following values are given once the audio is processed, e.g.
+    # at the target sample rate and target number of channels.
+    n_frames: int      # actual number of frames without padding
+    total_frames: int  # total number of frames, padding included
+    sample_rate: int   # actual sample rate
+    channels: int      # number of audio channels.
+
+
+logger = logging.getLogger(__name__)
+TextCondition = tp.Optional[str]  # a text condition can be a string or None (if it doesn't exist)
+ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask
+
+
+class WavCondition(tp.NamedTuple):
+    wav: torch.Tensor
+    length: torch.Tensor
+    sample_rate: tp.List[int]
+    path: tp.List[tp.Optional[str]] = []
+    seek_time: tp.List[tp.Optional[float]] = []
+
+
+class JointEmbedCondition(tp.NamedTuple):
+    wav: torch.Tensor
+    text: tp.List[tp.Optional[str]]
+    length: torch.Tensor
+    sample_rate: tp.List[int]
+    path: tp.List[tp.Optional[str]] = []
+    seek_time: tp.List[tp.Optional[float]] = []
+
+
+@dataclass
+class ConditioningAttributes:
+    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
+    wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
+    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
+
+    def __getitem__(self, item):
+        return getattr(self, item)
+
+    @property
+    def text_attributes(self):
+        return self.text.keys()
+
+    @property
+    def wav_attributes(self):
+        return self.wav.keys()
+
+    @property
+    def joint_embed_attributes(self):
+        return self.joint_embed.keys()
+
+    @property
+    def attributes(self):
+        return {
+            "text": self.text_attributes,
+            "wav": self.wav_attributes,
+            "joint_embed": self.joint_embed_attributes,
+        }
+
+    def to_flat_dict(self):
+        return {
+            **{f"text.{k}": v for k, v in self.text.items()},
+            **{f"wav.{k}": v for k, v in self.wav.items()},
+            **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
+        }
+
+    @classmethod
+    def from_flat_dict(cls, x):
+        out = cls()
+        for k, v in x.items():
+            kind, att = k.split(".")
+            out[kind][att] = v
+        return out
+
+
+
+# class SegmentWithAttributes(SegmentInfo):
+#     """Base class for all dataclasses that are used for conditioning.
+#     All child classes should implement `to_condition_attributes` that converts
+#     the existing attributes to a dataclass of type ConditioningAttributes.
+#     """
+#     def to_condition_attributes(self) -> ConditioningAttributes:
+#         raise NotImplementedError()
+
+
+
+def nullify_condition(condition: ConditionType, dim: int = 1):
+    """Transform an input condition to a null condition.
+    This is done by converting it to a single zero vector, similarly
+    to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
+
+    Args:
+        condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
+        dim (int): The dimension that will be truncated (should be the time dimension)
+            WARNING!: dim should not be the batch dimension!
+    Returns:
+        ConditionType: A tuple of null condition and mask
+    """
+    assert dim != 0, "dim cannot be the batch dimension!"
+    assert isinstance(condition, tuple) and \
+        isinstance(condition[0], torch.Tensor) and \
+        isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!"
+    cond, mask = condition
+    B = cond.shape[0]
+    last_dim = cond.dim() - 1
+    out = cond.transpose(dim, last_dim)
+    out = 0. * out[..., :1]
+    out = out.transpose(dim, last_dim)
+    mask = torch.zeros((B, 1), device=out.device).int()
+    assert cond.dim() == out.dim()
+    return out, mask
+
+
+def nullify_wav(cond: WavCondition) -> WavCondition:
+    """Transform a WavCondition to a nullified WavCondition.
+    It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes.
+
+    Args:
+        cond (WavCondition): Wav condition with wav, tensor of shape [B, T].
+    Returns:
+        WavCondition: Nullified wav condition.
+    """
+    # TODO by YCY: fix this to support zero-length input (as None)
+    null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1)  # B,1 all-zero
+    return WavCondition(
+        wav=null_wav,
+        length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device),
+        sample_rate=cond.sample_rate,
+        path=[None] * cond.wav.shape[0],
+        seek_time=[None] * cond.wav.shape[0],
+    )
+
+
+def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
+    """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0,
+    and replacing metadata by dummy attributes.
+
+    Args:
+        embed (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
+    """
+    null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1)
+    return JointEmbedCondition(
+        wav=null_wav, text=[None] * len(embed.text),
+        length=torch.LongTensor([0]).to(embed.wav.device),
+        sample_rate=embed.sample_rate,
+        path=[None] * embed.wav.shape[0],
+        seek_time=[0] * embed.wav.shape[0],
+    )
+
+
+
+class BaseConditioner(nn.Module):
+    """Base model for all conditioner modules.
+    We allow the output dim to be different than the hidden dim for two reasons:
+    1) keep our LUTs small when the vocab is large;
+    2) make all condition dims consistent.
+
+    Args:
+        dim (int): Hidden dim of the model.
+        output_dim (int): Output dim of the conditioner.
+    """
+    def __init__(self, dim: int, output_dim: int, input_token=False, padding_idx=None):
+        super().__init__()
+        self.dim = dim
+        self.output_dim = output_dim
+        if input_token:
+            self.output_proj = nn.Embedding(dim, output_dim, padding_idx)
+        else:
+            self.output_proj = nn.Linear(dim, output_dim)
+
+    def tokenize(self, *args, **kwargs) -> tp.Any:
+        """Should be any part of the processing that will lead to a synchronization
+        point, e.g. BPE tokenization with transfer to the GPU.
+
+        The returned value will be saved and returned later when calling forward().
+        """
+        raise NotImplementedError()
+
+    def forward(self, inputs: tp.Any) -> ConditionType:
+        """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
+        Outputs a ConditionType, after the input data was embedded as a dense vector.
+
+        Returns:
+            ConditionType:
+                - A tensor of size [B, T, D] where B is the batch size, T is the length of the
+                  output embedding and D is the dimension of the embedding.
+                - And a mask indicating where the padding tokens are.
+        """
+        raise NotImplementedError()
+
+
+
+def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
+    """Utility function for nullifying an attribute inside a ConditioningAttributes object.
+    If the condition is of type "wav", then nullify it using the `nullify_condition` function.
+    If the condition is of any other type, set its value to None.
+    Works in-place.
+    """
+    if condition_type not in ['text', 'wav', 'joint_embed']:
+        raise ValueError(
+            "dropout_condition got an unexpected condition type!"
+            f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
+        )
+
+    if condition not in getattr(sample, condition_type):
+        raise ValueError(
+            "dropout_condition received an unexpected condition!"
+            f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
+            f" but got '{condition}' of type '{condition_type}'!"
+        )
+
+    if condition_type == 'wav':
+        wav_cond = sample.wav[condition]
+        sample.wav[condition] = nullify_wav(wav_cond)
+    elif condition_type == 'joint_embed':
+        embed = sample.joint_embed[condition]
+        sample.joint_embed[condition] = nullify_joint_embed(embed)
+    else:
+        sample.text[condition] = None
+
+    return sample
+
+
+class DropoutModule(nn.Module):
+    """Base module for all dropout modules."""
+    def __init__(self, seed: int = 1234):
+        super().__init__()
+        self.rng = torch.Generator()
+        self.rng.manual_seed(seed)
+
+
+class AttributeDropout(DropoutModule):
+    """Dropout with a given probability per attribute.
+    This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
+    to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
+    This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
+    must also be dropped.
+
+    Args:
+        p (tp.Dict[str, tp.Dict[str, float]]): A dict mapping condition types to attribute dropout probabilities. For example:
+            ...
+            "genre": 0.1,
+            "artist": 0.5,
+            "wav": 0.25,
+            ...
+        active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
+        seed (int, optional): Random seed.
+    """
+    def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
+        super().__init__(seed=seed)
+        self.active_on_eval = active_on_eval
+        # construct dict that returns the values from p, otherwise 0
+        self.p = {}
+        for condition_type, probs in p.items():
+            self.p[condition_type] = defaultdict(lambda: 0, probs)
+
+    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+        """
+        Args:
+            samples (list[ConditioningAttributes]): List of conditions.
+        Returns:
+            list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
+        """
+        if not self.training and not self.active_on_eval:
+            return samples
+
+        samples = deepcopy(samples)
+        for condition_type, ps in self.p.items():  # for condition types [text, wav]
+            for condition, p in ps.items():  # for attributes of each type (e.g., [artist, genre])
+                # import pdb; pdb.set_trace()
+                # print(condition, p)
+                if torch.rand(1, generator=self.rng).item() < p:
+                    for sample in samples:
+                        dropout_condition(sample, condition_type, condition)
+        return samples
+
+    def __repr__(self):
+        return f"AttributeDropout({dict(self.p)})"
+
+
+class ClassifierFreeGuidanceDropout(DropoutModule):
+    """Classifier Free Guidance dropout.
+    All attributes are dropped with the same probability.
+
+    Args:
+        p (float): Probability to apply condition dropout during training.
+        seed (int): Random seed.
+    """
+    def __init__(self, p: float, seed: int = 1234):
+        super().__init__(seed=seed)
+        self.p = p
+
+    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+        """
+        Args:
+            samples (list[ConditioningAttributes]): List of conditions.
+        Returns:
+            list[ConditioningAttributes]: List of conditions after all attributes were set to None.
+        """
+
+        if not self.training:
+            return samples
+        # import pdb; pdb.set_trace()
+        # decide on which attributes to drop in a batched fashion
+        drop = torch.rand(1, generator=self.rng).item() < self.p
+        if not drop:
+            return samples
+
+        # nullify conditions of all attributes
+        samples = deepcopy(samples)
+        for condition_type in ["text", "wav", "joint_embed"]:
+            for sample in samples:
+                for condition in sample.attributes[condition_type]:
+                    dropout_condition(sample, condition_type, condition)
+        return samples
+
+    def __repr__(self):
+        return f"ClassifierFreeGuidanceDropout(p={self.p})"
+
|
448 |
+
|
449 |
+
class TextConditioner(BaseConditioner):
|
450 |
+
...
|
451 |
+
|
452 |
+
|
453 |
+
class WaveformConditioner(BaseConditioner):
|
454 |
+
"""Base class for all conditioners that take a waveform as input.
|
455 |
+
Classes that inherit must implement `_get_wav_embedding` that outputs
|
456 |
+
a continuous tensor, and `_downsampling_factor` that returns the down-sampling
|
457 |
+
factor of the embedding model.
|
458 |
+
|
459 |
+
Args:
|
460 |
+
dim (int): The internal representation dimension.
|
461 |
+
output_dim (int): Output dimension.
|
462 |
+
"""
|
463 |
+
def __init__(self, dim: int, output_dim: int, input_token = False, padding_idx=None):
|
464 |
+
super().__init__(dim, output_dim, input_token, padding_idx)
|
465 |
+
|
466 |
+
def tokenize(self, x: WavCondition) -> WavCondition:
|
467 |
+
wav, length, sample_rate, path, seek_time = x
|
468 |
+
assert length is not None
|
469 |
+
return WavCondition(wav, length, sample_rate, path, seek_time)
|
470 |
+
|
471 |
+
def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
|
472 |
+
"""Gets as input a WavCondition and returns a dense embedding."""
|
473 |
+
raise NotImplementedError()
|
474 |
+
|
475 |
+
def _downsampling_factor(self):
|
476 |
+
"""Returns the downsampling factor of the embedding model."""
|
477 |
+
raise NotImplementedError()
|
478 |
+
|
479 |
+
def forward(self, x: WavCondition) -> ConditionType:
|
480 |
+
"""Extract condition embedding and mask from a waveform and its metadata.
|
481 |
+
Args:
|
482 |
+
x (WavCondition): Waveform condition containing raw waveform and metadata.
|
483 |
+
Returns:
|
484 |
+
ConditionType: a dense vector representing the conditioning along with its mask
|
485 |
+
"""
|
486 |
+
|
487 |
+
wav, lengths, *_ = x
|
488 |
+
# import pdb; pdb.set_trace()
|
489 |
+
with torch.no_grad():
|
490 |
+
embeds = self._get_wav_embedding(x)
|
491 |
+
embeds = embeds.to(self.output_proj.weight)
|
492 |
+
embeds = self.output_proj(embeds)
|
493 |
+
# import pdb; pdb.set_trace()
|
494 |
+
if lengths is not None:
|
495 |
+
lengths = lengths / self._downsampling_factor()
|
496 |
+
mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
|
497 |
+
else:
|
498 |
+
mask = torch.ones_like(embeds)
|
499 |
+
embeds = (embeds * mask.unsqueeze(2))
|
500 |
+
|
501 |
+
return embeds, mask
|
502 |
+
|
503 |
+
|
504 |
+
class JointEmbeddingConditioner(BaseConditioner):
|
505 |
+
"""Joint embedding conditioning supporting both audio or text conditioning.
|
506 |
+
|
507 |
+
Args:
|
508 |
+
dim (int): Dimension.
|
509 |
+
output_dim (int): Output dimension.
|
510 |
+
autocast_dtype (str): Autocast for the conditioner.
|
511 |
+
quantize (bool): Whether to quantize the CLAP embedding.
|
512 |
+
n_q (int): Number of residual quantizers (used if quantize is true).
|
513 |
+
bins (int): Quantizers' codebooks size (used if quantize is true).
|
514 |
+
kwargs: Additional parameters for residual vector quantizer.
|
515 |
+
"""
|
516 |
+
def __init__(self, dim: int, output_dim: int,
|
517 |
+
autocast_dtype: tp.Optional[str] = 'float32', #quantize: bool = False,
|
518 |
+
**kwargs):
|
519 |
+
super().__init__(dim=dim, output_dim=output_dim)
|
520 |
+
self.autocast_dtype = getattr(torch, autocast_dtype) if autocast_dtype is not None \
|
521 |
+
else None
|
522 |
+
if self.autocast_dtype is None:
|
523 |
+
logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
|
524 |
+
|
525 |
+
# # residual vector quantizer to discretize the conditioned embedding
|
526 |
+
# self.quantizer = None
|
527 |
+
# if quantize:
|
528 |
+
# from ..modules.quantization import ResidualVectorQuantizer
|
529 |
+
# self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs)
|
530 |
+
|
531 |
+
def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
532 |
+
"""Get joint embedding in latent space from the inputs.
|
533 |
+
|
534 |
+
Returns:
|
535 |
+
tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
|
536 |
+
and corresponding empty indexes.
|
537 |
+
"""
|
538 |
+
raise NotImplementedError()
|
539 |
+
|
540 |
+
def forward(self, x: JointEmbedCondition) -> ConditionType:
|
541 |
+
with torch.cuda.amp.autocast(dtype=self.autocast_dtype):
|
542 |
+
embed, empty_idx = self._get_embed(x)
|
543 |
+
if self.quantizer is not None:
|
544 |
+
embed = embed.view(-1, self.dim, 1)
|
545 |
+
q_res = self.quantizer(embed, frame_rate=1)
|
546 |
+
out_embed = q_res.x.view(-1, self.dim)
|
547 |
+
else:
|
548 |
+
out_embed = embed
|
549 |
+
out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim)
|
550 |
+
mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
|
551 |
+
mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
|
552 |
+
out_embed = (out_embed * mask.unsqueeze(-1))
|
553 |
+
return out_embed, mask
|
554 |
+
|
555 |
+
def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
|
556 |
+
return x
|
557 |
+
|
558 |
+
|
559 |
+
class ConditioningProvider(nn.Module):
|
560 |
+
"""Prepare and provide conditions given all the supported conditioners.
|
561 |
+
|
562 |
+
Args:
|
563 |
+
conditioners (dict): Dictionary of conditioners.
|
564 |
+
"""
|
565 |
+
def __init__(self, conditioners: tp.Dict[str, BaseConditioner]):
|
566 |
+
super().__init__()
|
567 |
+
self.conditioners = nn.ModuleDict(conditioners)
|
568 |
+
def _check_conditioner_type(c):
|
569 |
+
if isinstance(c, WaveformConditioner):
|
570 |
+
return "wav"
|
571 |
+
elif isinstance(c, TextConditioner):
|
572 |
+
return "text"
|
573 |
+
elif isinstance(c, JointEmbeddingConditioner):
|
574 |
+
return "joint_embed"
|
575 |
+
else:
|
576 |
+
raise NotImplementedError(f"{type(c)} are not Implemented!")
|
577 |
+
self.conditioner_type = {k: _check_conditioner_type(self.conditioners[k]) for k in self.conditioners}
|
578 |
+
|
579 |
+
|
580 |
+
@property
|
581 |
+
def joint_embed_conditions(self):
|
582 |
+
return [k for k, v in self.conditioners.items() if isinstance(v, JointEmbeddingConditioner)]
|
583 |
+
|
584 |
+
@property
|
585 |
+
def has_joint_embed_conditions(self):
|
586 |
+
return len(self.joint_embed_conditions) > 0
|
587 |
+
|
588 |
+
@property
|
589 |
+
def text_conditions(self):
|
590 |
+
return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
|
591 |
+
|
592 |
+
@property
|
593 |
+
def wav_conditions(self):
|
594 |
+
return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
|
595 |
+
|
596 |
+
@property
|
597 |
+
def has_wav_condition(self):
|
598 |
+
return len(self.wav_conditions) > 0
|
599 |
+
|
600 |
+
def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
|
601 |
+
"""Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
|
602 |
+
This should be called before starting any real GPU work to avoid synchronization points.
|
603 |
+
This will return a dict matching conditioner names to their arbitrary tokenized representations.
|
604 |
+
|
605 |
+
Args:
|
606 |
+
inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
|
607 |
+
text and wav conditions.
|
608 |
+
"""
|
609 |
+
assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
|
610 |
+
"Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
|
611 |
+
f" but types were {set([type(x) for x in inputs])}"
|
612 |
+
)
|
613 |
+
|
614 |
+
# import pdb; pdb.set_trace()
|
615 |
+
output = {}
|
616 |
+
text = self._collate_text(inputs)
|
617 |
+
wavs = self._collate_wavs(inputs)
|
618 |
+
joint_embeds = self._collate_joint_embeds(inputs)
|
619 |
+
|
620 |
+
assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
|
621 |
+
f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
|
622 |
+
f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
|
623 |
+
)
|
624 |
+
|
625 |
+
for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
|
626 |
+
output[attribute] = self.conditioners[attribute].tokenize(batch)
|
627 |
+
return output
|
628 |
+
|
629 |
+
def forward(self, tokenized: tp.Dict[str, tp.Any], texts = None) -> tp.Dict[str, ConditionType]:
|
630 |
+
"""Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
|
631 |
+
The output is for example:
|
632 |
+
{
|
633 |
+
"genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
|
634 |
+
"description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
|
635 |
+
...
|
636 |
+
}
|
637 |
+
|
638 |
+
Args:
|
639 |
+
tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
|
640 |
+
"""
|
641 |
+
# import pdb; pdb.set_trace()
|
642 |
+
output = {}
|
643 |
+
for attribute, inputs in tokenized.items():
|
644 |
+
if attribute == 'self_wav' and texts is not None:
|
645 |
+
condition, mask = self.conditioners[attribute](inputs, texts = texts)
|
646 |
+
else:
|
647 |
+
condition, mask = self.conditioners[attribute](inputs)
|
648 |
+
output[attribute] = (condition, mask)
|
649 |
+
return output
|
650 |
+
|
651 |
+
def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
|
652 |
+
"""Given a list of ConditioningAttributes objects, compile a dictionary where the keys
|
653 |
+
are the attributes and the values are the aggregated input per attribute.
|
654 |
+
For example:
|
655 |
+
Input:
|
656 |
+
[
|
657 |
+
ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
|
658 |
+
ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
|
659 |
+
]
|
660 |
+
Output:
|
661 |
+
{
|
662 |
+
"genre": ["Rock", "Hip-hop"],
|
663 |
+
"description": ["A rock song with a guitar solo", "A hip-hop verse"]
|
664 |
+
}
|
665 |
+
|
666 |
+
Args:
|
667 |
+
samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
|
668 |
+
Returns:
|
669 |
+
dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
|
670 |
+
"""
|
671 |
+
out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
|
672 |
+
texts = [x.text for x in samples]
|
673 |
+
for text in texts:
|
674 |
+
for condition in self.text_conditions:
|
675 |
+
out[condition].append(text[condition])
|
676 |
+
return out
|
677 |
+
|
678 |
+
def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
|
679 |
+
"""Generate a dict where the keys are attributes by which we fetch similar wavs,
|
680 |
+
and the values are Tensors of wavs according to said attributes.
|
681 |
+
|
682 |
+
*Note*: by the time the samples reach this function, each sample should have some waveform
|
683 |
+
inside the "wav" attribute. It should be either:
|
684 |
+
1. A real waveform
|
685 |
+
2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
|
686 |
+
3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
|
687 |
+
|
688 |
+
Args:
|
689 |
+
samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
|
690 |
+
Returns:
|
691 |
+
dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
|
692 |
+
"""
|
693 |
+
# import pdb; pdb.set_trace()
|
694 |
+
wavs = defaultdict(list)
|
695 |
+
lengths = defaultdict(list)
|
696 |
+
sample_rates = defaultdict(list)
|
697 |
+
paths = defaultdict(list)
|
698 |
+
seek_times = defaultdict(list)
|
699 |
+
out: tp.Dict[str, WavCondition] = {}
|
700 |
+
|
701 |
+
for sample in samples:
|
702 |
+
for attribute in self.wav_conditions:
|
703 |
+
wav, length, sample_rate, path, seek_time = sample.wav[attribute]
|
704 |
+
assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
|
705 |
+
assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
|
706 |
+
# mono-channel conditioning
|
707 |
+
# wav = wav.mean(1, keepdim=True) # [1, 1, T] # by cyy, 为了实现后续功能注释掉了,请手动确保channel=1,or 输入channel 符合预期
|
708 |
+
wavs[attribute].append(wav.flatten()) # [C*T]
|
709 |
+
lengths[attribute].append(length)
|
710 |
+
sample_rates[attribute].extend(sample_rate)
|
711 |
+
paths[attribute].extend(path)
|
712 |
+
seek_times[attribute].extend(seek_time)
|
713 |
+
|
714 |
+
# stack all wavs to a single tensor
|
715 |
+
for attribute in self.wav_conditions:
|
716 |
+
stacked_wav, _ = collate(wavs[attribute], dim=0)
|
717 |
+
out[attribute] = WavCondition(
|
718 |
+
stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
|
719 |
+
paths[attribute], seek_times[attribute])
|
720 |
+
|
721 |
+
return out
|
722 |
+
|
723 |
+
def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
|
724 |
+
"""Generate a dict where the keys are attributes by which we compute joint embeddings,
|
725 |
+
and the values are Tensors of pre-computed embeddings and the corresponding text attributes.
|
726 |
+
|
727 |
+
Args:
|
728 |
+
samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
|
729 |
+
Returns:
|
730 |
+
A dictionary mapping an attribute name to joint embeddings.
|
731 |
+
"""
|
732 |
+
texts = defaultdict(list)
|
733 |
+
wavs = defaultdict(list)
|
734 |
+
lengths = defaultdict(list)
|
735 |
+
sample_rates = defaultdict(list)
|
736 |
+
paths = defaultdict(list)
|
737 |
+
seek_times = defaultdict(list)
|
738 |
+
channels: int = 0
|
739 |
+
|
740 |
+
out = {}
|
741 |
+
for sample in samples:
|
742 |
+
for attribute in self.joint_embed_conditions:
|
743 |
+
wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
|
744 |
+
assert wav.dim() == 3
|
745 |
+
if channels == 0:
|
746 |
+
channels = wav.size(1)
|
747 |
+
else:
|
748 |
+
assert channels == wav.size(1), "not all audio has same number of channels in batch"
|
749 |
+
assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
|
750 |
+
wav = einops.rearrange(wav, "b c t -> (b c t)") # [1, C, T] => [C * T]
|
751 |
+
wavs[attribute].append(wav)
|
752 |
+
texts[attribute].extend(text)
|
753 |
+
lengths[attribute].append(length)
|
754 |
+
sample_rates[attribute].extend(sample_rate)
|
755 |
+
paths[attribute].extend(path)
|
756 |
+
seek_times[attribute].extend(seek_time)
|
757 |
+
|
758 |
+
for attribute in self.joint_embed_conditions:
|
759 |
+
stacked_texts = texts[attribute]
|
760 |
+
stacked_paths = paths[attribute]
|
761 |
+
stacked_seek_times = seek_times[attribute]
|
762 |
+
stacked_wavs = pad_sequence(wavs[attribute])
|
763 |
+
stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
|
764 |
+
stacked_sample_rates = sample_rates[attribute]
|
765 |
+
stacked_lengths = torch.cat(lengths[attribute])
|
766 |
+
|
767 |
+
assert stacked_lengths.size(0) == stacked_wavs.size(0)
|
768 |
+
assert len(stacked_sample_rates) == stacked_wavs.size(0)
|
769 |
+
assert len(stacked_texts) == stacked_wavs.size(0)
|
770 |
+
out[attribute] = JointEmbedCondition(
|
771 |
+
text=stacked_texts, wav=stacked_wavs,
|
772 |
+
length=stacked_lengths, sample_rate=stacked_sample_rates,
|
773 |
+
path=stacked_paths, seek_time=stacked_seek_times)
|
774 |
+
|
775 |
+
return out
|
776 |
+
|
777 |
+
|
778 |
+
class ConditionFuser(StreamingModule):
|
779 |
+
"""Condition fuser handles the logic to combine the different conditions
|
780 |
+
to the actual model input.
|
781 |
+
|
782 |
+
Args:
|
783 |
+
fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
|
784 |
+
each condition. For example:
|
785 |
+
{
|
786 |
+
"prepend": ["description"],
|
787 |
+
"sum": ["genre", "bpm"],
|
788 |
+
"cross": ["description"],
|
789 |
+
}
|
790 |
+
cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
|
791 |
+
cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
|
792 |
+
"""
|
793 |
+
FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
|
794 |
+
|
795 |
+
def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
|
796 |
+
cross_attention_pos_emb_scale: float = 1.0):
|
797 |
+
super().__init__()
|
798 |
+
assert all(
|
799 |
+
[k in self.FUSING_METHODS for k in fuse2cond.keys()]
|
800 |
+
), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
|
801 |
+
self.cross_attention_pos_emb = cross_attention_pos_emb
|
802 |
+
self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
|
803 |
+
self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
|
804 |
+
self.cond2fuse: tp.Dict[str, str] = {}
|
805 |
+
for fuse_method, conditions in fuse2cond.items():
|
806 |
+
for condition in conditions:
|
807 |
+
self.cond2fuse[condition] = fuse_method
|
808 |
+
|
809 |
+
def forward(
|
810 |
+
self,
|
811 |
+
input: torch.Tensor,
|
812 |
+
conditions: tp.Dict[str, ConditionType]
|
813 |
+
) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
814 |
+
"""Fuse the conditions to the provided model input.
|
815 |
+
|
816 |
+
Args:
|
817 |
+
input (torch.Tensor): Transformer input.
|
818 |
+
conditions (dict[str, ConditionType]): Dict of conditions.
|
819 |
+
Returns:
|
820 |
+
tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
|
821 |
+
after the conditions have been fused. The second output tensor is the tensor
|
822 |
+
used for cross-attention or None if no cross attention inputs exist.
|
823 |
+
"""
|
824 |
+
# import pdb; pdb.set_trace()
|
825 |
+
B, T, _ = input.shape
|
826 |
+
|
827 |
+
if 'offsets' in self._streaming_state:
|
828 |
+
first_step = False
|
829 |
+
offsets = self._streaming_state['offsets']
|
830 |
+
else:
|
831 |
+
first_step = True
|
832 |
+
offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
|
833 |
+
|
834 |
+
assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
|
835 |
+
f"given conditions contain unknown attributes for fuser, " \
|
836 |
+
f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
|
837 |
+
cross_attention_output = None
|
838 |
+
prepend_input = input[:, :0]
|
839 |
+
for cond_type, (cond, cond_mask) in conditions.items():
|
840 |
+
op = self.cond2fuse[cond_type]
|
841 |
+
if op == 'sum':
|
842 |
+
input += cond
|
843 |
+
elif op == 'input_interpolate':
|
844 |
+
cond = einops.rearrange(cond, "b t d -> b d t")
|
845 |
+
cond = F.interpolate(cond, size=input.shape[1])
|
846 |
+
input += einops.rearrange(cond, "b d t -> b t d")
|
847 |
+
elif op == 'prepend':
|
848 |
+
prepend_input = torch.cat([cond.to(input.dtype), prepend_input], dim=1)
|
849 |
+
# NOTE 这里cond应该在后,这样顺序才符合配置文件,否则为逆序
|
850 |
+
# 但是之前实验是这样的为了保持一致就没改
|
851 |
+
elif op == 'cross':
|
852 |
+
if cross_attention_output is not None:
|
853 |
+
cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
|
854 |
+
else:
|
855 |
+
cross_attention_output = cond
|
856 |
+
else:
|
857 |
+
raise ValueError(f"unknown op ({op})")
|
858 |
+
|
859 |
+
if self.cross_attention_pos_emb and cross_attention_output is not None:
|
860 |
+
positions = torch.arange(
|
861 |
+
cross_attention_output.shape[1],
|
862 |
+
device=cross_attention_output.device
|
863 |
+
).view(1, -1, 1)
|
864 |
+
pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
|
865 |
+
cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
|
866 |
+
|
867 |
+
if first_step:
|
868 |
+
input = torch.cat([prepend_input, input], dim=1)
|
869 |
+
if self._is_streaming:
|
870 |
+
self._streaming_state['offsets'] = offsets + T
|
871 |
+
|
872 |
+
return input, cross_attention_output
|
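
A sketch of the fuser contract, assuming the default (non-streaming) state from StreamingModule: `prepend` conditions extend the time axis of the returned input, while `cross` conditions are returned separately for cross-attention. Dimensions and names are arbitrary:

import torch

fuser = ConditionFuser(fuse2cond={'prepend': ['lyric'], 'cross': ['prompt'],
                                  'sum': [], 'input_interpolate': []})
x = torch.randn(2, 100, 512)  # transformer input [B, T, D]
conditions = {
    'lyric':  (torch.randn(2, 30, 512), torch.ones(2, 30)),   # prepended -> output T = 130
    'prompt': (torch.randn(2, 20, 512), torch.ones(2, 20)),   # routed to cross-attention
}
out, cross = fuser(x, conditions)
print(out.shape, cross.shape)  # torch.Size([2, 130, 512]) torch.Size([2, 20, 512])
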
SongBloom/models/musicgen/conditioners/text.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base import *
|
2 |
+
|
3 |
+
import spacy
|
4 |
+
import warnings
|
5 |
+
import random
|
6 |
+
import hashlib
|
7 |
+
from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer, AutoTokenizer, XLMRobertaModel, XLMRobertaTokenizer # type: ignore
|
8 |
+
from num2words import num2words
|
9 |
+
|
10 |
+
def hash_trick(word: str, vocab_size: int) -> int:
|
11 |
+
"""Hash trick to pair each word with an index
|
12 |
+
|
13 |
+
Args:
|
14 |
+
word (str): word we wish to convert to an index
|
15 |
+
vocab_size (int): size of the vocabulary
|
16 |
+
Returns:
|
17 |
+
int: index of the word in the embedding LUT
|
18 |
+
"""
|
19 |
+
|
20 |
+
hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16)
|
21 |
+
return hash % vocab_size
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
class PhonemeTokenizerConditioner(TextConditioner):
|
26 |
+
def __init__(self,
|
27 |
+
output_dim: int,
|
28 |
+
vocab_list,
|
29 |
+
max_len = 600,
|
30 |
+
max_sentence_per_structure = 50,
|
31 |
+
structure_tokens=None,
|
32 |
+
structure_split_tokens=[','],
|
33 |
+
sentence_split_tokens=['.'],
|
34 |
+
mode='sum',
|
35 |
+
structure_output_dim = 64,
|
36 |
+
sentence_output_dim = 64,
|
37 |
+
max_duration = 120,
|
38 |
+
interpolate = False,
|
39 |
+
):
|
40 |
+
|
41 |
+
self.vocab_list = vocab_list
|
42 |
+
self.max_len = max_len
|
43 |
+
self.mode = mode
|
44 |
+
self.max_sentence_per_structure = max_sentence_per_structure
|
45 |
+
voc_size = len(self.vocab_list)
|
46 |
+
self.interpolate = interpolate
|
47 |
+
|
48 |
+
if structure_tokens is None:
|
49 |
+
structure_tokens = [i for i in vocab_list if len(i) > 1 and i[0] == '[' and i[-1] == ']']
|
50 |
+
self.structure_token_ids = [vocab_list.index(i) for i in structure_tokens if i in vocab_list]
|
51 |
+
self.structure_split_token_ids = [vocab_list.index(i) for i in structure_split_tokens]
|
52 |
+
self.sentence_split_token_ids = [vocab_list.index(i) for i in sentence_split_tokens]
|
53 |
+
|
54 |
+
# here initialize a output_proj (nn.Embedding) layer
|
55 |
+
# By default the first vocab is "" (null)
|
56 |
+
if mode == 'sum':
|
57 |
+
content_output_dim = output_dim
|
58 |
+
sentence_output_dim = output_dim
|
59 |
+
structure_output_dim = output_dim
|
60 |
+
else: # concat
|
61 |
+
content_output_dim = output_dim - sentence_output_dim - structure_output_dim # by default
|
62 |
+
|
63 |
+
super().__init__(voc_size, content_output_dim, input_token=True, padding_idx=0)
|
64 |
+
if self.mode != 'sum':
|
65 |
+
self.special_emb = nn.Embedding(len(self.structure_token_ids)+len(self.structure_split_token_ids)+len(self.sentence_split_token_ids)+1,
|
66 |
+
structure_output_dim, padding_idx=0)
|
67 |
+
|
68 |
+
self.blank_emb = nn.Parameter(torch.zeros(1, output_dim), requires_grad=False)
|
69 |
+
|
70 |
+
# the first index is "empty structure" token
|
71 |
+
self.sentence_idx_in_structure_emb = nn.Embedding(max_sentence_per_structure, sentence_output_dim, padding_idx=0)
|
72 |
+
|
73 |
+
# print("max_len", self.max_len)
|
74 |
+
print(self.structure_token_ids)
|
75 |
+
|
76 |
+
self.resolution = max_duration / max_len # e.g., 120 / 600 = 0.2s
|
77 |
+
print(self.__class__, f"resolution = {self.resolution}")
|
78 |
+
|
79 |
+
def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
|
80 |
+
inputs = []
|
81 |
+
for xx in x:
|
82 |
+
xx = '' if xx is None else xx
|
83 |
+
vocab_id = [self.vocab_list.index(item) for item in xx.split(" ") if item in self.vocab_list]
|
84 |
+
inputs.append(torch.tensor(vocab_id).long()) # [T]
|
85 |
+
return inputs
|
86 |
+
|
87 |
+
|
88 |
+
def interpolate_with_structure_duration(self, special_tokens, embeds, structure_dur):
|
89 |
+
# embeds: [T, N]
|
90 |
+
def sec2idx(sec): # convert duration sec to token index
|
91 |
+
return int(sec / self.resolution)
|
92 |
+
|
93 |
+
def target_token_types2list(tokens, target_token_types):
|
94 |
+
|
95 |
+
is_target_list = torch.any(torch.stack([tokens == i for i in target_token_types], dim=-1), dim=-1)
|
96 |
+
is_target_list = torch.where(is_target_list)[0].tolist()
|
97 |
+
return is_target_list
|
98 |
+
|
99 |
+
structure_ids = []
|
100 |
+
for (structure, st, et) in structure_dur:
|
101 |
+
structure_ids.append([structure, sec2idx(st), sec2idx(et)])
|
102 |
+
|
103 |
+
"""
|
104 |
+
interpolate embeddings of each structure according to its duration
|
105 |
+
"""
|
106 |
+
is_structure_list = target_token_types2list(special_tokens, self.structure_token_ids)
|
107 |
+
is_structure_list.append(special_tokens.shape[-1])
|
108 |
+
|
109 |
+
split_tokens = deepcopy(self.structure_split_token_ids)
|
110 |
+
split_tokens.extend(self.sentence_split_token_ids)
|
111 |
+
# is_split_list = target_token_types2list(special_tokens, split_tokens)
|
112 |
+
|
113 |
+
|
114 |
+
interpolated_embeds = embeds[:is_structure_list[0]]
|
115 |
+
for i, st in enumerate(is_structure_list[:-1]):
|
116 |
+
# (lorry) Explain "-tmp":
|
117 |
+
# All structures are connected with " , " token,
|
118 |
+
# " ," is also the final token of each structure except the final one,
|
119 |
+
            # but here we don't want to interpolate the " , " token
            tmp = 1
            if i == len(is_structure_list[:-1]) - 1:  # the final structure, no need for "-1"
                tmp = 0

            # print(st, is_structure_list[i+1]-tmp)
            to_interpolate = embeds[st: is_structure_list[i+1] - tmp]
            interpolate_size = structure_ids[i][2] - structure_ids[i][1] - tmp
            # print(interpolate_size)

            # import pdb; pdb.set_trace()
            # print(interpolated_embeds.shape, to_interpolate.shape, interpolate_size, )
            if to_interpolate.shape[0] == 0:
                import pdb; pdb.set_trace()
            this_interpolated_embeds = F.interpolate(to_interpolate.unsqueeze(0).transpose(2, 1),
                                                     size=interpolate_size,
                                                     mode='nearest-exact').squeeze(0).transpose(1, 0)

            if tmp == 1:
                interpolated_embeds = torch.cat((interpolated_embeds, this_interpolated_embeds,
                                                 embeds[is_structure_list[i+1]].unsqueeze(0)), 0)
            else:
                interpolated_embeds = torch.cat((interpolated_embeds, this_interpolated_embeds), 0)
        return interpolated_embeds


    def forward(self, batch_tokens: tp.List, structure_dur=None) -> ConditionType:
        """
        Encode token ids into three types of embeddings:
        1) content embedding: phonemes only (i.e. the meaningful content to be sung out);
        2) structure embedding: structure / separation embeddings, covering structure labels
           (verse/chorus/...) and separators (. / ,). These two share the same embedding
           layer, which could be changed to separate embedding layers;
        3) sentence-index embedding (per structure).
        """
        embeds_batch = []
        # print(batch_tokens)
        for b in range(len(batch_tokens)):
            tokens = batch_tokens[b]

            content_tokens = torch.zeros_like(tokens)
            special_tokens = torch.zeros_like(tokens)
            sentence_idx_in_structure_tokens = torch.zeros_like(tokens)

            current_structure_idx = 1
            current_sentence_in_structure_idx = 1
            current_structure = 0

            for i in range(tokens.shape[-1]):
                token = tokens[i]
                if token in self.structure_token_ids:  # structure token
                    # only update the structure token, leave content and sentence index tokens null (default 0)
                    if self.mode == 'sum':
                        special_tokens[i] = token
                    else:
                        special_tokens[i] = self.structure_token_ids.index(token) + 1
                    current_structure = token
                    current_structure_idx += 1
                    current_sentence_in_structure_idx = 1

                elif token in self.sentence_split_token_ids:  # utterance split token
                    # only update the structure token, leave content and sentence index tokens null (default 0)
                    # increment the sentence index
                    if self.mode == 'sum':
                        special_tokens[i] = token
                    else:
                        special_tokens[i] = self.sentence_split_token_ids.index(token) + 1 + len(self.structure_token_ids)
                    current_sentence_in_structure_idx += 1

                elif token in self.structure_split_token_ids:  # structure split token
                    # update structure token (current structure), content token (current token),
                    # blank index token
                    if self.mode == 'sum':
                        special_tokens[i] = token
                    else:
                        special_tokens[i] = self.structure_split_token_ids.index(token) + 1 + len(self.structure_token_ids) + len(self.sentence_split_token_ids)

                else:  # content tokens
                    content_tokens[i] = token
                    special_tokens[i] = current_structure
                    sentence_idx_in_structure_tokens[i] = min(current_sentence_in_structure_idx, self.max_sentence_per_structure - 1)

            # print("tokens", tokens.max(), tokens.min())
            # print("special tokens", special_tokens.max(), special_tokens.min())
            # print("sentence idx in structure", sentence_idx_in_structure_tokens.max(), sentence_idx_in_structure_tokens.min())
            device = self.output_proj.weight.device

            # import pdb; pdb.set_trace()
            content_embeds = self.output_proj(tokens.to(device))  # [T, N]
            if self.mode == 'sum':
                structure_embeds = self.output_proj(special_tokens.to(device))
            else:
                structure_embeds = self.special_emb(special_tokens.to(device))
            sentence_idx_embeds = self.sentence_idx_in_structure_emb(sentence_idx_in_structure_tokens.to(device))

            if self.mode == 'sum':
                embeds = content_embeds + structure_embeds + sentence_idx_embeds
            else:
                embeds = torch.cat((content_embeds, structure_embeds, sentence_idx_embeds), -1)  # [T, N]

            if self.interpolate:
                embeds = self.interpolate_with_structure_duration(tokens, embeds, structure_dur[b])
            embeds_batch.append(embeds)

        # set batch_size = 1, [B, T, N]
        if self.max_len is not None:
            max_len = self.max_len
        else:
            max_len = max([e.shape[0] for e in embeds_batch])
        embeds, mask = self.pad_2d_tensor(embeds_batch, max_len)

        return embeds, mask


    def pad_2d_tensor(self, xs, max_len):
        new_tensor = []
        new_mask = []
        for x in xs:
            seq_len, dim = x.size()
            pad_len = max_len - seq_len

            if pad_len > 0:
                pad_tensor = self.blank_emb.repeat(pad_len, 1).to(x.device)  # T, D
                padded_tensor = torch.cat([x, pad_tensor], dim=0)
                mask = torch.cat((torch.ones_like(x[:, 0]),
                                  torch.zeros_like(pad_tensor[:, 0])), 0)  # T
            elif pad_len < 0:
                padded_tensor = x[:max_len]
                mask = torch.ones_like(padded_tensor[:, 0])
            else:
                padded_tensor = x
                mask = torch.ones_like(x[:, 0])

            new_tensor.append(padded_tensor)
            new_mask.append(mask)
        # [B, T, D] & [B, T]
        return torch.stack(new_tensor, 0), torch.stack(new_mask, 0)
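A minimal sketch (not from the repo; shapes are illustrative) of the length adjustment performed by `interpolate_with_structure_duration` above: a (T, D) span of embeddings is re-timed to a target frame count with the same nearest-exact `F.interpolate` call used in the method.

import torch
import torch.nn.functional as F

span = torch.randn(5, 16)                # (T, D) embedding span to re-time
target_len = 8                           # frame count implied by the structure duration
stretched = F.interpolate(span.unsqueeze(0).transpose(2, 1),  # -> (1, D, T)
                          size=target_len, mode='nearest-exact')
stretched = stretched.squeeze(0).transpose(1, 0)              # -> (target_len, D)
assert stretched.shape == (8, 16)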
SongBloom/models/musicgen/conditioners/wav.py
ADDED
@@ -0,0 +1,74 @@
from .base import *
import omegaconf
from ...vae_frontend import AbstractVAE

def pad_to_fix_length(x, max_len, pad_value=0.):
    bsz, seq_len = x.shape[:2]
    if seq_len >= max_len:
        return x[:, :max_len]
    else:
        pad_len = max_len - seq_len
        pad_tensor = torch.full((bsz, pad_len, *x.shape[2:]), pad_value, dtype=x.dtype, device=x.device)
        padded_tensor = torch.cat([x, pad_tensor], dim=1)
        return padded_tensor

class AudioTokenizerConditioner(WaveformConditioner):
    def __init__(self, output_dim, audio_tokenizer, cache=False, max_len=None):
        super().__init__(output_dim, output_dim)
        self.max_len = max_len
        self.use_cache = cache

        self.tokenizer = audio_tokenizer
        # breakpoint()

        # TODO: if cached and the vae is not loaded, receive a dict instead
        if isinstance(self.tokenizer, dict):
            self.tokenizer = omegaconf.DictConfig(self.tokenizer)
            self.code_depth = self.tokenizer.channel_dim

        elif isinstance(self.tokenizer, AbstractVAE):
            self.tokenizer_tp = "vae"
            if self.use_cache:
                self.code_depth = self.tokenizer.channel_dim
            else:
                self.code_depth = 1  # TODO: the input channel count is forced to 1 here (instead of self.tokenizer.input_channel)
            self.output_proj = nn.Identity() if self.output_dim == self.tokenizer.channel_dim \
                else nn.Linear(self.tokenizer.channel_dim, self.output_dim, bias=False)

        else:
            raise NotImplementedError


    def forward(self, x: WavCondition):
        wav, lengths, *_ = x
        B = wav.shape[0]
        wav = wav.reshape(B, self.code_depth, -1)
        # print(wav.shape)
        # import torchaudio
        # torchaudio.save("/apdcephfs_cq7/share_1297902/common/erichtchen/shixisheng/cyy/project/music_generation_repo/core/models/musicgen/conditioners/111.wav", wav[0].cpu(), 48000)
        if self.tokenizer_tp == "vae":
            if self.use_cache:
                audio_latents = wav.transpose(-1, -2)
            else:
                with torch.no_grad():
                    audio_latents = self.tokenizer.encode(wav).transpose(-1, -2)
                # print('transform wav to vae')
            audio_latents = self.output_proj(audio_latents)

        # print(audio_latents.shape)
        if self.max_len is not None:
            audio_latents = pad_to_fix_length(audio_latents, self.max_len, 0.)

        if lengths is not None:
            lengths = torch.round(lengths.float() * audio_latents.shape[1] / wav.shape[-1])
            mask = length_to_mask(lengths, max_len=audio_latents.shape[1]).int()  # type: ignore
        else:
            mask = torch.ones((B, audio_latents.shape[1]), device=audio_latents.device, dtype=torch.int)

        audio_latents = audio_latents * mask[..., None]

        return audio_latents, mask
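For reference, a quick illustration (shapes chosen arbitrarily, not from the repo) of how `pad_to_fix_length` behaves on a batch-first tensor:

import torch

x = torch.randn(2, 10, 4)                            # (batch, time, channels)
assert pad_to_fix_length(x, 16).shape == (2, 16, 4)  # zero-padded along the time axis
assert pad_to_fix_length(x, 8).shape == (2, 8, 4)    # truncated when already too long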
SongBloom/models/musicgen/get_backend.py
ADDED
@@ -0,0 +1,76 @@
import torch
import os, sys
from transformers.utils import is_flash_attn_2_available
from transformers.models.llama import LlamaModel, LlamaConfig
from transformers.models.bart.modeling_bart import BartEncoder, BartDecoder, BartConfig
import warnings
# from transformers.models.musicgen.modeling_musicgen import MusicgenModel, MusicgenDecoder, MusicgenDecoderConfig  # this is just a BartDecoder, but without cross-attn

try:
    assert is_flash_attn_2_available()
    assert torch.cuda.get_device_capability(torch.device("cuda")) >= (8, 0)
    assert os.environ.get("DISABLE_FLASH_ATTN", '0') != "1"
    _enable_flash_attention = True
except Exception:
    _enable_flash_attention = False

if not _enable_flash_attention:
    warnings.warn("Flash attention 2 is not supported; falling back to eager attention.")

def get_backend(name, dim, num_heads, num_layers, hidden_scale, init_std=0.02, rope_theta=10000,):
    # SA (causal) - FF
    if name == 'llama':
        model_cfg = LlamaConfig(
            hidden_size=dim,
            intermediate_size=dim * hidden_scale,
            num_attention_heads=num_heads,
            num_hidden_layers=num_layers,
            num_key_value_heads=num_heads,
            vocab_size=dim,
            use_cache=False,
            max_position_embeddings=4096,
            hidden_act="silu",
            initializer_range=init_std,
            rope_theta=rope_theta,
            _attn_implementation="flash_attention_2" if _enable_flash_attention else "eager",
        )
        model = LlamaModel(model_cfg)

    # SA - FF
    elif name == 'bart_enc':
        model_cfg = BartConfig(
            d_model=dim,
            max_position_embeddings=4096,
            dropout=0.,
            use_cache=False,
            _attn_implementation="flash_attention_2" if _enable_flash_attention else "eager",
            activation_function='gelu',
            # for BartEncoder
            encoder_layers=num_layers,
            encoder_attention_heads=num_heads,
            init_std=init_std,
            encoder_ffn_dim=dim * hidden_scale,
        )
        model = BartEncoder(model_cfg)

    # SA - CA - FF
    elif name == 'bart_dec':
        model_cfg = BartConfig(
            d_model=dim,
            max_position_embeddings=4096,
            dropout=0.,
            use_cache=False,
            _attn_implementation="flash_attention_2" if _enable_flash_attention else "eager",
            activation_function='gelu',
            # for BartDecoder
            decoder_layers=num_layers,
            decoder_attention_heads=num_heads,
            decoder_ffn_dim=dim * hidden_scale,
        )
        model = BartDecoder(model_cfg)

    else:
        raise NotImplementedError

    delattr(model, "embed_tokens")
    return model
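A small usage sketch, assuming a recent `transformers` version: because `embed_tokens` is deleted, the returned backbone has to be driven with `inputs_embeds` rather than token ids. The sizes below are illustrative only.

import torch

backbone = get_backend('llama', dim=256, num_heads=4, num_layers=2, hidden_scale=4)
h = torch.randn(1, 16, 256)                        # (batch, seq, dim) precomputed embeddings
out = backbone(inputs_embeds=h).last_hidden_state  # runs the transformer stack directly
assert out.shape == (1, 16, 256)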
SongBloom/models/musicgen/modules/streaming.py
ADDED
@@ -0,0 +1,125 @@
"""
Streaming module API that should be implemented by all streaming components.
"""

from contextlib import contextmanager
import typing as tp
from torch import nn
import torch


State = tp.Dict[str, torch.Tensor]


class StreamingModule(nn.Module):
    """Common API for streaming components.

    Each streaming component has a streaming state, which is just a dict[str, Tensor].
    By convention, the first dim of each tensor must be the batch size.
    Don't use dots in the key names, as this would clash with submodules
    (like in state_dict).

    If `self._is_streaming` is True, the component should use and remember
    the proper state inside `self._streaming_state`.

    To set a streaming component in streaming state, use

        with module.streaming():
            ...

    This will automatically reset the streaming state when exiting the context manager.
    This also automatically propagates to all streaming children modules.

    Some modules might also implement the `StreamingModule.flush` method, although
    this one is trickier, as all parent modules must be StreamingModule and implement
    it as well for it to work properly. See `StreamingSequential` below.
    """
    def __init__(self) -> None:
        super().__init__()
        self._streaming_state: State = {}
        self._is_streaming = False

    def _apply_named_streaming(self, fn: tp.Any):
        for name, module in self.named_modules():
            if isinstance(module, StreamingModule):
                fn(name, module)

    def _set_streaming(self, streaming: bool):
        def _set_streaming(name, module):
            module._is_streaming = streaming
        self._apply_named_streaming(_set_streaming)

    @contextmanager
    def streaming(self):
        """Context manager to enter streaming mode. Reset streaming state on exit."""
        self._set_streaming(True)
        try:
            yield
        finally:
            self._set_streaming(False)
            self.reset_streaming()

    def reset_streaming(self):
        """Reset the streaming state."""
        def _reset(name: str, module: StreamingModule):
            module._streaming_state.clear()

        self._apply_named_streaming(_reset)

    def get_streaming_state(self) -> State:
        """Return the streaming state, including that of sub-modules."""
        state: State = {}

        def _add(name: str, module: StreamingModule):
            if name:
                name += "."
            for key, value in module._streaming_state.items():
                state[name + key] = value

        self._apply_named_streaming(_add)
        return state

    def set_streaming_state(self, state: State):
        """Set the streaming state, including that of sub-modules."""
        state = dict(state)

        def _set(name: str, module: StreamingModule):
            if name:
                name += "."
            module._streaming_state.clear()
            for key, value in list(state.items()):
                # complexity is not ideal here, but probably fine.
                if key.startswith(name):
                    local_key = key[len(name):]
                    if '.' not in local_key:
                        module._streaming_state[local_key] = value
                        del state[key]

        self._apply_named_streaming(_set)
        assert len(state) == 0, list(state.keys())

    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush any remaining outputs that were waiting for completion.
        Typically, for convolutions, this will add the final padding
        and process the last buffer.

        This should take an optional argument `x`, which will be provided
        if a module before this one in the streaming pipeline has already
        spit out a flushed buffer.
        """
        if x is None:
            return None
        else:
            return self(x)


class StreamingSequential(StreamingModule, nn.Sequential):
    """A streaming compatible alternative of `nn.Sequential`.
    """
    def flush(self, x: tp.Optional[torch.Tensor] = None):
        for module in self:
            if isinstance(module, StreamingModule):
                x = module.flush(x)
            elif x is not None:
                x = module(x)
        return x
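A minimal sketch (not part of the repo) of how the streaming state is meant to be used: a toy module that keeps a running sum in `_streaming_state` while inside the `streaming()` context, with the state cleared automatically on exit.

import torch

class RunningSum(StreamingModule):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._is_streaming:
            prev = self._streaming_state.get("sum", torch.zeros_like(x))
            self._streaming_state["sum"] = prev + x  # batch-first, per the convention above
            return self._streaming_state["sum"]
        return x

m = RunningSum()
with m.streaming():
    m(torch.ones(1, 4))
    out = m(torch.ones(1, 4))        # tensor of 2s: state persists inside the context
assert not m.get_streaming_state()   # state is reset once the context exits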
SongBloom/models/musicldm/__init__.py
ADDED
File without changes
SongBloom/models/musicldm/inference/__init__.py
ADDED
File without changes
SongBloom/models/musicldm/inference/sampling.py
ADDED
@@ -0,0 +1,271 @@
import torch
import math
from tqdm import trange, tqdm

# k_diffusion is only needed by make_cond_model_fn / sample_k below;
# guard the import so the Euler samplers work without it.
try:
    import k_diffusion as K
except ImportError:
    K = None

# Define the noise schedule and sampling loop
def get_alphas_sigmas(t):
    """Returns the scaling factors for the clean image (alpha) and for the
    noise (sigma), given a timestep."""
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)

def alpha_sigma_to_t(alpha, sigma):
    """Returns a timestep, given the scaling factors for the clean image and for
    the noise."""
    return torch.atan2(sigma, alpha) / math.pi * 2

def t_to_alpha_sigma(t):
    """Returns the scaling factors for the clean image and for the noise, given
    a timestep."""
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)


@torch.no_grad()
def sample_discrete_euler(model, x, steps, sigma_max=1.0, prog_bar=False, **extra_args):
    """Draws samples from a model given starting noise. Euler method."""

    # Make tensor of ones to broadcast the single t values
    ts = x.new_ones([x.shape[0]])

    # Create the noise schedule
    t = torch.linspace(sigma_max, 0, steps + 1)
    # all = {}

    # alphas, sigmas = 1-t, t
    iterator = tqdm(zip(t[:-1], t[1:]), total=steps) if prog_bar else zip(t[:-1], t[1:])
    for t_curr, t_prev in iterator:
        # Broadcast the current timestep to the correct shape
        t_curr_tensor = t_curr * torch.ones(
            (x.shape[0],), dtype=x.dtype, device=x.device
        )
        dt = t_prev - t_curr  # we solve backwards in our formulation
        v = model(x, t_curr_tensor, **extra_args)
        # all[t_curr.item()] = x - t_curr * v
        x = x + dt * v  # .denoise(x, denoiser, t_curr_tensor, cond, uc)

    # If we are on the last timestep, output the denoised image
    return x  # , all

@torch.no_grad()
def sample_discrete_euler_with_temperature(model, x, steps, temperature=1.0, sigma_max=1.0, prog_bar=False, **extra_args):
    """Draws samples from a model given starting noise. Euler method."""

    # Make tensor of ones to broadcast the single t values
    ts = x.new_ones([x.shape[0]])
    noise = x

    # Create the noise schedule
    t = torch.linspace(sigma_max, 0, steps + 1)
    # all = {}
    x = torch.zeros_like(noise)
    if temperature >= sigma_max:
        x = noise

    # alphas, sigmas = 1-t, t
    iterator = tqdm(zip(t[:-1], t[1:]), total=steps) if prog_bar else zip(t[:-1], t[1:])
    for t_curr, t_prev in iterator:
        # Broadcast the current timestep to the correct shape
        t_curr_tensor = t_curr * torch.ones(
            (x.shape[0],), dtype=x.dtype, device=x.device
        )
        dt = t_prev - t_curr  # we solve backwards in our formulation
        v = model(x, t_curr_tensor, **extra_args)
        # all[t_curr.item()] = x - t_curr * v
        if t_curr > temperature and t_prev <= temperature:
            x_0 = x - v
            x = (1 - t_prev) * x_0 + t_prev * noise
        else:
            x = x + dt * v  # .denoise(x, denoiser, t_curr_tensor, cond, uc)

    # If we are on the last timestep, output the denoised image
    return x  # , all


@torch.no_grad()
def sample(model, x, steps, eta, prog_bar=False, **extra_args):
    """Draws samples from a model given starting noise. v-diffusion."""
    ts = x.new_ones([x.shape[0]])
    origin_dtype = x.dtype
    # Create the noise schedule
    t = torch.linspace(1, 0, steps + 1)[:-1]

    alphas, sigmas = get_alphas_sigmas(t)

    # The sampling loop
    bar = trange if prog_bar else range
    for i in bar(steps):

        # Get the model output (v, the predicted velocity)
        with torch.cuda.amp.autocast():
            v = model(x, ts * t[i], **extra_args).float()

        # Predict the noise and the denoised image
        pred = x * alphas[i] - v * sigmas[i]
        eps = x * sigmas[i] + v * alphas[i]

        # If we are not on the last timestep, compute the noisy image for the
        # next timestep.
        if i < steps - 1:
            # If eta > 0, adjust the scaling factor for the predicted noise
            # downward according to the amount of additional noise to add
            ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
                (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
            adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()

            # Recombine the predicted noise and predicted denoised image in the
            # correct proportions for the next step
            x = pred * alphas[i + 1] + eps * adjusted_sigma
            # Add the correct amount of fresh noise
            if eta:
                x += torch.randn_like(x) * ddim_sigma

    # If we are on the last timestep, output the denoised image
    return pred.to(origin_dtype)

# Soft mask inpainting is just shrinking hard (binary) mask inpainting
# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
def get_bmask(i, steps, mask):
    strength = (i + 1) / steps
    # convert to binary mask
    bmask = torch.where(mask <= strength, 1, 0)
    return bmask

def make_cond_model_fn(model, cond_fn):
    def cond_model_fn(x, sigma, **kwargs):
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            denoised = model(x, sigma, **kwargs)
            cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
            cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
        return cond_denoised
    return cond_model_fn

# Uses k-diffusion from https://github.com/crowsonkb/k-diffusion
# init_data is init_audio as latents (if this is latent diffusion)
# For sampling, set both init_data and mask to None
# For variations, set init_data
# For inpainting, set both init_data & mask
def sample_k(
    model_fn,
    noise,
    init_data=None,
    mask=None,
    steps=100,
    sampler_type="dpmpp-2m-sde",
    sigma_min=0.5,
    sigma_max=50,
    rho=1.0, device="cuda",
    callback=None,
    cond_fn=None,
    **extra_args
):

    denoiser = K.external.VDenoiser(model_fn)

    if cond_fn is not None:
        denoiser = make_cond_model_fn(denoiser, cond_fn)

    # Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has
    sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
    # Scale the initial noise by sigma
    noise = noise * sigmas[0]

    wrapped_callback = callback

    if mask is None and init_data is not None:
        # VARIATION (no inpainting)
        # set the initial latent to the init_data, and noise it with initial sigma
        x = init_data + noise
    elif mask is not None and init_data is not None:
        # INPAINTING
        bmask = get_bmask(0, steps, mask)
        # initial noising
        input_noised = init_data + noise
        # set the initial latent to a mix of init_data and noise, based on step 0's binary mask
        x = input_noised * bmask + noise * (1 - bmask)
        # define the inpainting callback function (Note: side effects, it mutates x)
        # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105
        # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)`
        def inpainting_callback(args):
            i = args["i"]
            x = args["x"]
            sigma = args["sigma"]
            # denoised = args["denoised"]
            # noise the init_data input with this step's appropriate amount of noise
            input_noised = init_data + torch.randn_like(init_data) * sigma
            # shrinking hard mask
            bmask = get_bmask(i, steps, mask)
            # mix input_noise with x, using binary mask
            new_x = input_noised * bmask + x * (1 - bmask)
            # mutate x
            x[:, :, :] = new_x[:, :, :]
        # wrap together the inpainting callback and the user-submitted callback.
        if callback is None:
            wrapped_callback = inpainting_callback
        else:
            wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
    else:
        # SAMPLING
        # set the initial latent to noise
        x = noise


    with torch.cuda.amp.autocast():
        if sampler_type == "k-heun":
            return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-lms":
            return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpmpp-2s-ancestral":
            return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-2":
            return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-fast":
            return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-adaptive":
            return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "dpmpp-2m-sde":
            return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "dpmpp-3m-sde":
            return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)

# Uses discrete Euler sampling for rectified flow models
# init_data is init_audio as latents (if this is latent diffusion)
# For sampling, set both init_data and mask to None
# For variations, set init_data
# For inpainting, set both init_data & mask
def sample_rf(
    model_fn,
    noise,
    init_data=None,
    steps=100,
    sigma_max=1,
    device="cuda",
    callback=None,
    cond_fn=None,
    **extra_args
):

    if sigma_max > 1:
        sigma_max = 1

    if cond_fn is not None:
        model_fn = make_cond_model_fn(model_fn, cond_fn)

    wrapped_callback = callback

    if init_data is not None:
        # VARIATION (no inpainting)
        # Interpolate the init data and the noise for init audio
        x = init_data * (1 - sigma_max) + noise * sigma_max
    else:
        # SAMPLING
        # set the initial latent to noise
        x = noise

    with torch.cuda.amp.autocast():
        # TODO: Add callback support
        # return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args)
        return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args)
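Two quick sanity checks on the schedules above (illustrative, not from the repo): the v-diffusion scaling factors satisfy alpha^2 + sigma^2 = 1, and `sample_discrete_euler` integrates a velocity field from t = 1 down to t = 0, so a dummy model with zero velocity returns the starting noise unchanged.

import torch

a, s = get_alphas_sigmas(torch.rand(5))
assert torch.allclose(a**2 + s**2, torch.ones(5))   # cos^2 + sin^2 = 1

zero_velocity = lambda x, t: torch.zeros_like(x)    # dx/dt = 0 at every step
x0 = torch.randn(2, 8)
assert torch.equal(sample_discrete_euler(zero_velocity, x0, steps=4), x0)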
SongBloom/models/musicldm/musicldm_dit.py
ADDED
@@ -0,0 +1,24 @@
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
import logging
import math

import torch
from torch import nn


class FourierFeatures(nn.Module):
    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        self.weight = nn.Parameter(torch.randn(
            [out_features // 2, in_features]) * std)

    def forward(self, input):
        f = 2 * math.pi * input @ self.weight.T
        return torch.cat([f.cos(), f.sin()], dim=-1)
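For intuition, a short illustration (sizes arbitrary): a scalar timestep in [0, 1] is projected onto random frequencies and expanded into concatenated cosine and sine features.

import torch

feats = FourierFeatures(in_features=1, out_features=256)
t = torch.rand(8, 1)           # e.g. diffusion timesteps
emb = feats(t)                 # 128 cosines followed by 128 sines
assert emb.shape == (8, 256)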
SongBloom/models/songbloom/songbloom_mvsa.py
ADDED
@@ -0,0 +1,572 @@
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
import logging
import math
import typing as tp

import torch
from torch import nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from einops import rearrange
import tqdm

from ..base.utils import create_norm_fn
from ..base.sample import sample_top_k, sample_top_p, multinomial
from ..musicgen.modules.streaming import StreamingModule
from ..musicgen.conditioners import (
    get_condition_fuser,
    get_conditioner_provider,
    ConditionType,
    ConditioningProvider,
    ConditionFuser,
    AttributeDropout,
    ClassifierFreeGuidanceDropout,
    ConditioningAttributes,
    WavCondition,
    JointEmbedCondition
)

from ..musicgen.get_backend import get_backend

from ..transformer import ContinuousTransformer as DiT_block
from ..musicldm.musicldm_dit import FourierFeatures
from ..musicldm.inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler, sample_discrete_euler_with_temperature

ConditionTensors = tp.Dict[str, ConditionType]
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]

@dataclass
class DiTAROutput:
    ar_logit: torch.Tensor
    ar_target: torch.Tensor
    nar_pred: torch.Tensor
    nar_target: torch.Tensor
    nar_t: torch.Tensor


class MVSA_DiTAR(StreamingModule):
    """
    Multiple skeleton embeddings, single compressed VAE latent,
    e.g. V1 V2 V3 A1-3 V4 V5 V6 A4-6
    V -> cross entropy (skeleton)
    A -> local-DiT uncompress -> (A1-3 -> E1 E2 E3)
    """

    def __init__(self, condition_provider_cfg, fuser_cfg,
                 block_size: int = 32, dim: int = 1024, num_heads: int = 8,
                 num_pitch: int = 128, hidden_scale: int = 4, lm_layers: int = 16,
                 norm: str = 'layer_norm', pre_norm: bool = False,
                 backend='llama', init_std: float = 0.02,
                 # ======================
                 latent_dim: int = 64, diff_layers: int = 8,
                 time_cond_type: tp.Literal['adaLN', "prepend"] = "prepend",
                 timestep_features_dim: int = 256,
                 diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
                 timestep_sampler: tp.Literal["uniform", "logit_normal", "trunc_logit_normal"] = "uniform",
                 rotary_base_val=10000, h_dropout: float = None,
                 # ======================
                 cfg_dropout: float = 0, cfg_coef: float = 1.0,
                 attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}
                 ):
        super().__init__()

        self.condition_provider = get_conditioner_provider(condition_provider_cfg)
        self.fuser = get_condition_fuser(fuser_cfg)

        self.dim = dim
        self.latent_dim = latent_dim
        self.block_size = block_size

        self.cfg_coef = cfg_coef
        self.h_dropout = h_dropout if h_dropout is not None else 0.
        self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
        self.att_dropout = AttributeDropout(p=attribute_dropout)


        # Build AR lm
        self.num_pitch = num_pitch + 1  # self.num_pitch = <EOS>, self.num_pitch+1 = special
        self.skeleton_emb = nn.Embedding(self.num_pitch + 1, dim)
        self.bos_token = nn.Parameter(torch.empty(dim).normal_(mean=0.0, std=init_std), requires_grad=True)


        # self.lm_type = lm_type
        self.backend = backend

        if self.backend == 'llama':
            self.ar_transformer = get_backend('llama',
                dim, num_heads, lm_layers, hidden_scale, init_std=init_std, rope_theta=rotary_base_val)
            self.ar_transformer.gradient_checkpointing_enable()
        elif self.backend == 'bart':
            self.cross_encoder = get_backend('bart_enc',
                dim, num_heads, lm_layers // 4, hidden_scale, init_std=init_std)
            self.ar_transformer = get_backend('bart_dec',
                dim, num_heads, lm_layers, hidden_scale, init_std=init_std)
        else:
            raise NotImplementedError(f"Illegal backend: {self.backend}!")


        self.skeleton_classifier = nn.Sequential(nn.Linear(dim, dim, bias=False),
                                                 nn.SiLU(),
                                                 nn.Linear(dim, self.num_pitch),)

        self.pre_norm: tp.Optional[nn.Module] = None
        if pre_norm:
            self.pre_norm = create_norm_fn(norm, dim)
        self.reset_streaming()

        # Build NAR DiT
        self.block_conv = nn.Sequential(
            Rearrange("b d (n s) -> b n (s d)", s=self.block_size),
            nn.Linear(self.block_size * latent_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )
        self.project_in = nn.Linear(latent_dim, dim) if latent_dim != dim else nn.Identity()
        self.project_out = nn.Linear(dim, latent_dim) if latent_dim != dim else nn.Identity()

        self.timestep_features_dim = timestep_features_dim
        self.time_cond_type = time_cond_type
        assert self.time_cond_type in ['adaLN', "prepend"]
        self.timestep_features = FourierFeatures(1, timestep_features_dim)
        self.to_timestep_embed = nn.Sequential(
            nn.Linear(timestep_features_dim, dim, bias=False),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )

        self.nar_dit = DiT_block(
            dim=dim,
            depth=diff_layers,
            dim_heads=dim // num_heads,
            rotary_pos_emb=True,
            cross_attend=False,
            causal=False,
            ff_kwargs={"dim_ff": dim * hidden_scale, "no_bias": True},
            global_cond_dim=self.dim if self.time_cond_type == "adaLN" else None,
            rotary_base_val=rotary_base_val,
            # init_std=init_std
        )
        self.nar_dit.gradient_checkpointing_enable()

        self.diffusion_objective = diffusion_objective
        self.timestep_sampler = timestep_sampler
        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)

        self.init_weights(init_std=init_std)


    @property
    def special_token_id(self) -> int:
        return self.num_pitch


    @property
    def eos_token_id(self) -> int:
        return self.num_pitch - 1

    def forward(self, x_sketch, x_latent, x_len, condition_tensors) -> DiTAROutput:
        '''
        only for training: lm_forward + diffusion_forward (random t)
        x_sketch: (B, T)  # T % block_sz == 0 (no <eos> token), padded with <eos>
        x_latent: (B, D_{in}, T)
        '''
        # AR
        assert torch.all(x_len % self.block_size == 0), f"{x_len}"
        block_num = x_len // self.block_size

        sketch_emb = self.skeleton_emb(x_sketch)
        latent_emb = self.block_conv(x_latent)

        B, T, D = sketch_emb.shape

        lm_input = rearrange(torch.cat([rearrange(sketch_emb, "b (n s) d -> b n s d", s=self.block_size),
                                        latent_emb.unsqueeze(dim=2)], dim=2), "b n s d -> b (n s) d")
        lm_input = torch.cat([self.bos_token.reshape(1, 1, -1).expand(B, -1, -1),
                              lm_input], dim=1)  # add <sos>

        new_seq_len = x_len + block_num + 1

        ar_target = F.pad(x_sketch, (0, 1), value=self.eos_token_id)
        for b, l in enumerate(x_len):
            ar_target[b, l + 1:] = self.special_token_id  # used to mask out the redundant eos tokens


        lm_out = self.lm_forward(lm_input, condition_tensors)


        indices = torch.arange(lm_out.shape[1])
        h_ind = indices[(indices + 1) % (self.block_size + 1) == 0]
        not_h_ind = indices[(indices + 1) % (self.block_size + 1) != 0]

        x_sketch_logit = self.skeleton_classifier(lm_out[:, not_h_ind])

        # NAR (h + prev_block)
        h_pad = lm_out[:, h_ind]  # B, N, D
        h = torch.cat([hh[:hl] for hh, hl in zip(h_pad, block_num)], dim=0)
        block_semantic = rearrange(sketch_emb, "b (n s) d -> b n s d", s=self.block_size)  # B, N, 32, D
        current_block_semantic = torch.cat([bb[:bl] for bb, bl in zip(block_semantic, block_num)], dim=0)

        if self.training:  # for CFG
            drop_h_idx = torch.rand((h.shape[0], 1), device=h.device) < self.h_dropout
            h = torch.masked_fill(h, drop_h_idx, 0)
            # current_block_semantic = torch.masked_fill(current_block_semantic, drop_h_idx.unsqueeze(-1), 0)

            drop_s_idx = torch.rand((current_block_semantic.shape[0], 1), device=current_block_semantic.device) < self.h_dropout
            current_block_semantic = torch.masked_fill(current_block_semantic, drop_s_idx.unsqueeze(-1), 0)

        with torch.no_grad():
            block_latent = rearrange(x_latent, "b d (n s) -> b n s d", s=self.block_size)  # B, N, 32, D
            current_block = torch.cat([bb[:bl] for bb, bl in zip(block_latent, block_num)], dim=0)
            prev_block = torch.cat([bb[:bl] for bb, bl in zip(F.pad(block_latent, (0, 0, 0, 0, 1, 0)), block_num)], dim=0)

        # b_indices = torch.randperm(block_latent.shape[0])[:B*16]
        # h, current_block, prev_block = h[b_indices], current_block[b_indices], prev_block[b_indices]

        orig_type = x_latent.dtype
        with torch.cuda.amp.autocast(enabled=False):
            if self.timestep_sampler == "uniform":
                # Draw uniformly distributed continuous timesteps
                t = self.rng.draw(h.shape[0])[:, 0].to(device=h.device, dtype=h.dtype)
            elif self.timestep_sampler == "logit_normal":
                t = torch.sigmoid(torch.randn(h.shape[0], device=h.device, dtype=h.dtype))
            elif self.timestep_sampler == "trunc_logit_normal":
                # Draw from logistic truncated normal distribution
                from ..musicldm.musicldm_pl import truncated_logistic_normal_rescaled
                t = truncated_logistic_normal_rescaled(h.shape[0]).to(h.device)
                # Flip the distribution
                t = 1 - t

            # Calculate the noise schedule parameters for those timesteps
            if self.diffusion_objective == "v":
                alphas, sigmas = get_alphas_sigmas(t)
            elif self.diffusion_objective == "rectified_flow":
                alphas, sigmas = 1 - t, t
            # Combine the ground truth data and the noise
            alphas = alphas[:, None, None]
            sigmas = sigmas[:, None, None]
            noise = torch.randn_like(current_block)
            noised_inputs = current_block * alphas + noise * sigmas
            if self.diffusion_objective == "v":  # (a_t - a_{t-1})x_0 + (b_t - b_{t-1})e = -b x_0 + a e
                targets = noise * alphas - current_block * sigmas
            elif self.diffusion_objective == "rectified_flow":  # ||(X_T - X_0) - p(x_t, t)||
                targets = noise - current_block

        nar_output = self.diffusion_forward(noised_inputs.to(orig_type), t.to(orig_type), h, current_block_semantic, prev_block)

        return DiTAROutput(
            ar_logit=x_sketch_logit,
            ar_target=ar_target,
            nar_pred=nar_output,
            nar_target=targets.to(orig_type),
            nar_t=t
        )


    def lm_forward(self, sequence, condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
        # import pdb; pdb.set_trace()
        B, T, D = sequence.shape
        if self.pre_norm:
            sequence = self.pre_norm(sequence.to(self.pre_norm.weight.data.dtype))

        input_, cross_attention_input = self.fuser(sequence, condition_tensors)

        transformer_input = {
            "inputs_embeds": input_,
            "use_cache": self._is_streaming,
            "past_key_values": self._streaming_state.get('past_key_values', None),
        }
        if self.backend == 'bart':  # TODO: no need to recompute this at inference time
            # TODO attention_mask
            cross_attention_input = self.cross_encoder(inputs_embeds=cross_attention_input)
            transformer_input["encoder_hidden_states"] = cross_attention_input.last_hidden_state

        output = self.ar_transformer(**transformer_input)
        if self._is_streaming:
            self._streaming_state['past_key_values'] = output.past_key_values
        out = output.last_hidden_state


        if len(self.fuser.fuse2cond['prepend']) > 0:
            out = out[:, -T:, :]

        return out


    def diffusion_forward(self,
                          x: torch.Tensor,
                          t: torch.Tensor,  # B,
                          h: torch.Tensor,
                          s: torch.Tensor,  # B, self.block_size, D
                          history_x: torch.Tensor,
                          cfg_coef: float = None) -> torch.Tensor:

        if cfg_coef is not None:
            # only for inference
            assert not self.training
            x = torch.cat([x, x], dim=0)
            t = torch.cat([t, t], dim=0)
            h = torch.cat([h, torch.zeros_like(h)], dim=0)
            s = torch.cat([s, torch.zeros_like(s)], dim=0)
            history_x = torch.cat([history_x, history_x], dim=0)

        B, T, _ = x.shape

        input_ = self.project_in(torch.cat([history_x, x], dim=1))
        # print(h.shape, s.shape, input_.shape)
        input_ = torch.cat([h.unsqueeze(1), s, input_], dim=1)
        # Get the batch of timestep embeddings
        timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]))  # (b, embed_dim)
        # breakpoint()
        if self.time_cond_type == "prepend":
            input_ = torch.cat([timestep_embed.unsqueeze(1), input_], dim=1)

        transformer_input = {
            "x": input_,
            "global_cond": timestep_embed if self.time_cond_type == "adaLN" else None}

        output = self.nar_dit(**transformer_input)

        # remove the prefix from the model outputs
        output = output[:, -T:, :]
        output = self.project_out(output)

        if cfg_coef is not None:
            cond_output, uncond_output = torch.chunk(output, 2, dim=0)
            output = uncond_output + (cond_output - uncond_output) * cfg_coef

        return output  # [B, T, D]


    def _sample_next_block(self,
                           sequence: torch.Tensor,
                           prev_latents: torch.Tensor,
                           condition_tensors: tp.Optional[ConditionTensors] = None,
                           cfg_coef: tp.Optional[tp.Union[float, tp.List[float]]] = None,
                           steps: int = 50,
                           dit_cfg_type: str = 'h',
                           use_sampling: bool = False,
                           temp: float = 1.0,
                           diff_temp: float = 1.0,
                           top_k: int = 0,
                           top_p: float = 0.0,
                           penalty_token_pool: tp.Optional[list] = None) -> torch.Tensor:
        # inference: lm next_token -> (if % block_sz == 0) run the diffusion
        # 1. sample sketch (lm) -> 2. sample latent (lm + diff)
        sequence = sequence.clone()

        if isinstance(cfg_coef, tp.Iterable):
            assert len(cfg_coef) == 2
            cfg_coef_lm, cfg_coef_diff = cfg_coef
        else:
            cfg_coef_lm, cfg_coef_diff = cfg_coef, cfg_coef

        B = sequence.shape[0]
        # import pdb; pdb.set_trace()

        if condition_tensors:
            # Preparing for CFG, predicting both conditional and unconditional logits.
            sequence = torch.cat([sequence, sequence], dim=0)


        # ############### decode sketch #########################
        next_tokens = []
        next_token_embs = []

        for k in range(self.block_size):
            if self._is_streaming and k > 0:
                lm_inp = sequence[:, -1:]
            else:
                lm_inp = sequence

            lm_out = self.lm_forward(
                lm_inp,
                condition_tensors=condition_tensors)
            next_pitch_logit = self.skeleton_classifier(lm_out[:, -1:])  # B, 1, card

            if condition_tensors:
                cond_logit, uncond_logit = next_pitch_logit.split(B, dim=0)
                next_pitch_logit = uncond_logit + (cond_logit - uncond_logit) * cfg_coef_lm

            # add penalty to previously sampled tokens
            if penalty_token_pool is not None and len(penalty_token_pool) > 0:  # B, T
                for b in range(B):
                    # q_count = torch.bincount(penalty_token_pool)
                    q_count = torch.bincount(torch.unique(penalty_token_pool[b]))
                    tmp = min(q_count.shape[-1], self.num_pitch - 1)
                    next_pitch_logit[b, -1, :tmp] /= (1.1 ** q_count[:tmp])

            # sample k
            if use_sampling and temp > 0.0:
                probs = torch.softmax(next_pitch_logit / temp, dim=-1)
                if top_p > 0.0:
                    next_token = sample_top_p(probs, p=top_p)
                elif top_k > 0:
                    next_token = sample_top_k(probs, k=top_k)
                else:
                    next_token = multinomial(probs, num_samples=1)
                next_token = next_token.squeeze(-1)
            else:
                next_token = torch.argmax(next_pitch_logit, dim=-1)  # B, 1
            if penalty_token_pool is not None and len(penalty_token_pool) > 0:  # B, T
                penalty_token_pool = torch.cat([penalty_token_pool, next_token], dim=-1)[:, 1:]
            next_token_emb = self.skeleton_emb(next_token)  # B, 1, d
            next_tokens.append(next_token)
            next_token_embs.append(next_token_emb)

            if condition_tensors:
                doubled_next_emb = torch.cat([next_token_emb, next_token_emb], dim=0)
                sequence = torch.cat([sequence, doubled_next_emb], dim=1)
            else:
                sequence = torch.cat([sequence, next_token_emb], dim=1)

        next_tokens = torch.cat(next_tokens, dim=1)
        next_token_embs = torch.cat(next_token_embs, dim=1)

        # ############### decode latent ###########################
        # h is computed on the doubled batch here, but classifier-free guidance is not applied to it
        if self._is_streaming:
            lm_inp = sequence[:, -1:]
        else:
            lm_inp = sequence

        lm_out = self.lm_forward(
            lm_inp,
            condition_tensors=condition_tensors)

        h = lm_out[:, -1]

        noise = torch.randn((B, self.block_size, self.latent_dim), device=h.device, dtype=h.dtype)

        assert dit_cfg_type in ['h', 'global', 'none']
        """
        global: same cfg setting as next-token prediction
        none: no cfg
        h: no cfg during the ar stage; apply cfg via the ar output
        """
        if condition_tensors:
            if dit_cfg_type == 'global':
                noise = torch.cat([noise, noise], dim=0)
                prev_latents = torch.cat([prev_latents, prev_latents], dim=0)
                semantic_embs = torch.cat([next_token_embs, next_token_embs], dim=0)
            else:
                h, _ = h.chunk(2, dim=0)
                semantic_embs = next_token_embs


        if self.diffusion_objective == "v":
            next_latent = sample(self.diffusion_forward, noise, steps=steps, eta=0, h=h, s=semantic_embs, history_x=prev_latents,
                                 cfg_coef=(cfg_coef_diff if dit_cfg_type == 'h' else None))
        elif self.diffusion_objective == "rectified_flow":
            # next_latent = sample_discrete_euler(self.diffusion_forward, noise, steps=steps, h=h, s=semantic_embs, history_x=prev_latents,
            #                                     cfg_coef=(cfg_coef_diff if dit_cfg_type == 'h' else None))
            next_latent = sample_discrete_euler_with_temperature(self.diffusion_forward, noise, steps=steps, temperature=diff_temp, h=h, s=semantic_embs, history_x=prev_latents,
                                                                 cfg_coef=(cfg_coef_diff if dit_cfg_type == 'h' else None))
        if condition_tensors and dit_cfg_type == 'global':
            cond_next_latent, uncond_next_latent = torch.chunk(next_latent, 2, dim=0)
            next_latent = uncond_next_latent + (cond_next_latent - uncond_next_latent) * cfg_coef_diff

        latent_emb = self.block_conv(next_latent.transpose(1, 2))

        next_block_seq = torch.cat([next_token_embs, latent_emb], dim=1)  # B, self.block_size+1, d

        return next_tokens, next_latent, next_block_seq


    @torch.no_grad()
    def generate(self,
                 prompt: tp.Optional[torch.Tensor] = None,
                 conditions: tp.List[ConditioningAttributes] = [],
                 cfg_coef: tp.Optional[tp.Union[float, tp.List[float]]] = None,
                 steps=50,
                 dit_cfg_type: str = 'h',
                 max_frames: int = 1500,  # 60 * 25
                 use_sampling: bool = True,
                 temp: float = 1.0,
                 diff_temp: float = 1.0,
                 top_k: int = 0,
                 top_p: float = 0.0,
                 penalty_repeat: bool = False,
                 penalty_window: int = 50) -> torch.Tensor:
        assert not self.training, "generation shouldn't be used in training mode."

        B = len(conditions)
        assert B == 1, "batch decoding is currently not supported"
        null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
        conditions = conditions + null_conditions
        tokenized = self.condition_provider.tokenize(conditions)
        condition_tensors = self.condition_provider(tokenized)


        sequence = self.bos_token.reshape(1, 1, -1).expand(B, 1, -1)
        if prompt is not None:
            # TODO
            raise NotImplementedError
            # sequence = torch.cat([sequence, prompt])


        prev_blocks = torch.zeros((B, self.block_size, self.latent_dim), device=sequence.device, dtype=sequence.dtype)
        latent_seq, token_seq = None, None

        with self.streaming():
            prog_bar = tqdm.tqdm()
            while True:
                if token_seq is None or not penalty_repeat:
                    penalty_token_pool = None
                else:
                    penalty_token_pool = token_seq[:, -penalty_window:]
                    if penalty_token_pool.shape[-1] < penalty_window:
                        penalty_token_pool = F.pad(penalty_token_pool, (penalty_window - penalty_token_pool.shape[-1], 0), value=self.eos_token_id)
                next_tokens, next_latent, next_block_seq = self._sample_next_block(sequence[:, -1:], prev_blocks, condition_tensors,
                                                                                   cfg_coef=cfg_coef, steps=steps, dit_cfg_type=dit_cfg_type,
                                                                                   use_sampling=use_sampling, temp=temp, diff_temp=diff_temp,
                                                                                   top_k=top_k, top_p=top_p,
                                                                                   penalty_token_pool=penalty_token_pool)

                if (next_tokens == self.eos_token_id).any() or sequence.shape[1] > max_frames / self.block_size * (self.block_size + 1):
                    break

                latent_seq = next_latent if latent_seq is None else torch.cat([latent_seq, next_latent], dim=1)  # B, T, D
                token_seq = next_tokens if token_seq is None else torch.cat([token_seq, next_tokens], dim=1)  # B, T
                sequence = torch.cat([sequence, next_block_seq], dim=1)
                prev_blocks = next_latent

                prog_bar.update(self.block_size)


        if latent_seq is None:
            latent_seq = prev_blocks
        return latent_seq.transpose(1, 2), token_seq


    def init_weights(self, init_std=0.02):

        def _init_weights(module, init_std=0.02):
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=init_std)
                # torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=init_std)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()

        self.apply(partial(_init_weights, init_std=init_std))
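A minimal sketch (illustrative sizes, not from the repo) of the block-interleaved sequence that `forward` builds for the AR transformer: each block contributes `block_size` skeleton embeddings followed by one compressed acoustic-latent embedding, giving the V1..Vs A pattern described in the class docstring.

import torch
from einops import rearrange

B, n, s, D = 2, 4, 3, 8                  # batch, blocks, block_size, dim
sketch_emb = torch.randn(B, n * s, D)    # one embedding per skeleton token
latent_emb = torch.randn(B, n, D)        # one embedding per compressed block
interleaved = rearrange(
    torch.cat([rearrange(sketch_emb, "b (n s) d -> b n s d", s=s),
               latent_emb.unsqueeze(2)], dim=2),
    "b n s d -> b (n s) d")
assert interleaved.shape == (B, n * (s + 1), D)  # V1..Vs, A, V.., A, ...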
SongBloom/models/songbloom/songbloom_pl.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+
+from functools import partial
+import typing as tp
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import torchaudio
+import numpy as np
+import random
+from omegaconf import OmegaConf
+import copy
+import lightning as pl
+
+import os, sys
+
+from ..musicgen.conditioners import WavCondition, JointEmbedCondition, ConditioningAttributes
+from ..vae_frontend import StableVAE
+from .songbloom_mvsa import MVSA_DiTAR
+from ...g2p.lyric_common import key2processor, symbols, LABELS
+
+
+os.environ['TOKENIZERS_PARALLELISM'] = "false"
+
+
+class SongBloom_PL(pl.LightningModule):
+    def __init__(self, cfg):
+        super().__init__()
+        # disable automatic optimization (kept for reference)
+        # self.automatic_optimization = False
+
+        self.cfg = cfg
+
+        # Build VAE
+        self.vae = StableVAE(**cfg.vae).eval()
+        assert self.cfg.model['latent_dim'] == self.vae.channel_dim
+
+        self.save_hyperparameters(cfg)
+        if self.vae is not None:
+            for param in self.vae.parameters():
+                param.requires_grad = False
+
+        # Build DiT
+        model_cfg = OmegaConf.to_container(copy.deepcopy(cfg.model), resolve=True)
+        for cond_name in model_cfg["condition_provider_cfg"]:
+            if model_cfg["condition_provider_cfg"][cond_name]['type'] == 'audio_tokenizer_wrapper':
+                model_cfg["condition_provider_cfg"][cond_name]["audio_tokenizer"] = self.vae
+                model_cfg["condition_provider_cfg"][cond_name]["cache"] = False
+
+        self.model = MVSA_DiTAR(**model_cfg)
+        # print(self.model)
+
+
+####################################
+
+class SongBloom_Sampler:
+
+    def __init__(self, compression_model: StableVAE, diffusion: MVSA_DiTAR, lyric_processor_key,
+                 max_duration: float, prompt_duration: tp.Optional[float] = None):
+        self.compression_model = compression_model
+        self.diffusion = diffusion
+        self.lyric_processor_key = lyric_processor_key
+        self.lyric_processor = key2processor.get(lyric_processor_key) if lyric_processor_key is not None else lambda x: x
+
+        assert max_duration is not None
+        self.max_duration: float = max_duration
+        self.prompt_duration = prompt_duration
+
+        self.device = next(iter(diffusion.parameters())).device
+        self.generation_params: dict = {}
+        # self.set_generation_params(duration=15)  # 15 seconds by default
+        self.set_generation_params(cfg_coef=1.5, steps=50, dit_cfg_type='h',
+                                   use_sampling=True, top_k=200, max_frames=self.max_duration * 25)
+        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
+
+    @classmethod
+    def build_from_trainer(cls, cfg, strict=True, dtype=torch.float32):
+        model_light = SongBloom_PL(cfg)
+        incompatible = model_light.load_state_dict(torch.load(cfg.pretrained_path, map_location='cpu'), strict=strict)
+
+        lyric_processor_key = cfg.train_dataset.lyric_processor
+
+        print(incompatible)
+
+        model_light = model_light.eval().cuda().to(dtype=dtype)
+        model = cls(
+            compression_model=model_light.vae,
+            diffusion=model_light.model,
+            lyric_processor_key=lyric_processor_key,
+            max_duration=cfg.max_dur,
+            prompt_duration=cfg.sr * cfg.train_dataset.prompt_len
+        )
+        model.set_generation_params(**cfg.inference)
+        return model
+
+    @property
+    def frame_rate(self) -> float:
+        """Roughly the number of AR steps per second."""
+        return self.compression_model.frame_rate
+
+    @property
+    def sample_rate(self) -> int:
+        """Sample rate of the generated audio."""
+        return self.compression_model.sample_rate
+
+    def set_generation_params(self, **kwargs):
+        """Set the generation parameters."""
+        self.generation_params.update(kwargs)
+
+    # MuLan Inference
+    @torch.no_grad()
+    def generate(self, lyrics, prompt_wav) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
+        """Generate samples conditioned on lyrics and a reference waveform."""
+        assert prompt_wav.ndim == 2
+        if self.prompt_duration is not None:
+            prompt_wav = prompt_wav[..., :self.prompt_duration]
+
+        attributes, _ = self._prepare_tokens_and_attributes(conditions={"lyrics": [self._process_lyric(lyrics)], "prompt_wav": [prompt_wav]},
+                                                            prompt=None, prompt_tokens=None)
+
+        print(self.generation_params)
+        latent_seq, token_seq = self.diffusion.generate(None, attributes, **self.generation_params)
+        audio_recon = self.compression_model.decode(latent_seq).float()
+
+        return audio_recon
+
+    def _process_lyric(self, input_lyric):
+        if self.lyric_processor_key == 'pinyin':
+            processed_lyric = self.lyric_processor(input_lyric)
+        else:
+            processed_lyric = []
+            check_lyric = input_lyric.split(" ")
+            for ii in range(len(check_lyric)):
+                if check_lyric[ii] not in symbols and check_lyric[ii] not in LABELS.keys() and len(check_lyric[ii]) > 0:
+                    new = self.lyric_processor(check_lyric[ii])
+                    check_lyric[ii] = new
+            processed_lyric = " ".join(check_lyric)
+
+        return processed_lyric
+
+    @torch.no_grad()
+    def _prepare_tokens_and_attributes(
+        self,
+        conditions: tp.Dict[str, tp.List[tp.Union[str, torch.Tensor]]],
+        prompt: tp.Optional[torch.Tensor],
+        prompt_tokens: tp.Optional[torch.Tensor] = None,
+    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
+        """Prepare model inputs.
+
+        Args:
+            conditions (dict): Mapping from conditioner name to a list of
+                conditioning values (strings or waveforms), one per batch item.
+            prompt (torch.Tensor): A batch of waveforms used for continuation.
+            prompt_tokens (torch.Tensor, optional): Pre-encoded prompt latents.
+                Defaults to None.
+        """
+        batch_size = len(list(conditions.values())[0])
+        assert batch_size == 1
+        attributes = [ConditioningAttributes() for _ in range(batch_size)]
+        for k in self.diffusion.condition_provider.conditioners:
+            conds = conditions.pop(k, [None for _ in attributes])
+            for attr, cond in zip(attributes, conds):
+                if self.diffusion.condition_provider.conditioner_type[k] == 'wav':
+                    if cond is None:
+                        attr.wav[k] = WavCondition(
+                            torch.zeros((1, 1, 1), device=self.device),
+                            torch.tensor([0], device=self.device).long(),
+                            sample_rate=[self.sample_rate],
+                            path=[None])
+                    else:
+                        attr.wav[k] = WavCondition(
+                            cond.to(device=self.device).unsqueeze(0),  # 1, C, T
+                            torch.tensor([cond.shape[-1]], device=self.device).long(),
+                            sample_rate=[self.sample_rate],
+                            path=[None])
+                elif self.diffusion.condition_provider.conditioner_type[k] == 'text':
+                    attr.text[k] = cond
+                elif self.diffusion.condition_provider.conditioner_type[k] == 'joint_embed':
+                    if cond is None or isinstance(cond, str):
+                        attr.joint_embed[k] = JointEmbedCondition(
+                            torch.zeros((1, 1, 1), device=self.device),
+                            [cond],
+                            torch.tensor([0], device=self.device).long(),
+                            sample_rate=[self.sample_rate],
+                            path=[None])
+                    elif isinstance(cond, torch.Tensor):
+                        attr.joint_embed[k] = JointEmbedCondition(
+                            cond.to(device=self.device).mean(dim=0, keepdim=True).unsqueeze(0),
+                            [None],
+                            torch.tensor([cond.shape[-1]], device=self.device).long(),
+                            sample_rate=[self.sample_rate],
+                            path=[None])
+                    else:
+                        raise NotImplementedError
+        assert conditions == {}, f"Found illegal conditions: {conditions}, supported keys: {self.diffusion.condition_provider.conditioners}"
+        print(attributes)
+
+        if prompt_tokens is not None:
+            prompt_tokens = prompt_tokens.to(self.device)
+            assert prompt is None
+        elif prompt is not None:
+            assert len(attributes) == len(prompt), "Prompt and number of attributes don't match"
+            prompt = prompt.to(self.device)
+            prompt_tokens = self.compression_model.encode(prompt)
+        else:
+            prompt_tokens = None
+
+        return attributes, prompt_tokens
+
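For orientation, a hedged sketch of how SongBloom_Sampler is driven end to end. The config keys mirror what build_from_trainer reads above (pretrained_path, vae, model, max_dur, sr, train_dataset, inference); the YAML and audio file names are hypothetical placeholders, and a CUDA device plus real checkpoints are required:

# Hypothetical driver; "songbloom.yaml" and "prompt.flac" are placeholders.
import torchaudio
from omegaconf import OmegaConf

cfg = OmegaConf.load("songbloom.yaml")
sampler = SongBloom_Sampler.build_from_trainer(cfg)

wav, sr = torchaudio.load("prompt.flac")                         # (C, T) reference clip
wav = torchaudio.functional.resample(wav, sr, sampler.sample_rate)
audio = sampler.generate("[verse] some lyrics ...", wav.mean(dim=0, keepdim=True))
torchaudio.save("generated.flac", audio[0].cpu(), sampler.sample_rate)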
SongBloom/models/transformer.py
ADDED
@@ -0,0 +1,937 @@
+from functools import reduce, partial
+from packaging import version
+
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from torch.cuda.amp import autocast
+from typing import Callable, Literal
+import os, sys
+import warnings
+from torch.utils import checkpoint
+from transformers.utils import is_flash_attn_2_available
+
+try:
+    assert is_flash_attn_2_available()
+    assert torch.cuda.get_device_capability(torch.device("cuda")) >= (8, 0)
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, unpad_input, pad_input
+    assert os.environ.get("DISABLE_FLASH_ATTN", '0') != "1"
+except Exception as e:
+    flash_attn_kvpacked_func = None
+    flash_attn_func = None
+    warnings.warn("flash-attn is not supported in this environment!")
+
+try:
+    import natten
+except ImportError:
+    natten = None
+
+def checkpoint(function, *args, **kwargs):
+    kwargs.setdefault("use_reentrant", False)
+    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
+
+
+# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
+# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
+
+def create_causal_mask(i, j, device):
+    return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)
+
+def or_reduce(masks):
+    head, *body = masks
+    for rest in body:
+        head = head | rest
+    return head
+
+# positional embeddings
+
+class AbsolutePositionalEmbedding(nn.Module):
+    def __init__(self, dim, max_seq_len):
+        super().__init__()
+        self.scale = dim ** -0.5
+        self.max_seq_len = max_seq_len
+        self.emb = nn.Embedding(max_seq_len, dim)
+
+    def forward(self, x, pos=None, seq_start_pos=None):
+        seq_len, device = x.shape[1], x.device
+        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
+
+        if pos is None:
+            pos = torch.arange(seq_len, device=device)
+
+        if seq_start_pos is not None:
+            pos = (pos - seq_start_pos[..., None]).clamp(min=0)
+
+        pos_emb = self.emb(pos)
+        pos_emb = pos_emb * self.scale
+        return pos_emb
+
+class ScaledSinusoidalEmbedding(nn.Module):
+    def __init__(self, dim, theta=10000):
+        super().__init__()
+        assert (dim % 2) == 0, 'dimension must be divisible by 2'
+        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
+
+        half_dim = dim // 2
+        freq_seq = torch.arange(half_dim).float() / half_dim
+        inv_freq = theta ** -freq_seq
+        self.register_buffer('inv_freq', inv_freq, persistent=False)
+
+    def forward(self, x, pos=None, seq_start_pos=None):
+        seq_len, device = x.shape[1], x.device
+
+        if pos is None:
+            pos = torch.arange(seq_len, device=device)
+
+        if seq_start_pos is not None:
+            pos = pos - seq_start_pos[..., None]
+
+        emb = einsum('i, j -> i j', pos, self.inv_freq)
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb * self.scale
+
+class RotaryEmbedding(nn.Module):
+    def __init__(
+        self,
+        dim,
+        use_xpos=False,
+        scale_base=512,
+        interpolation_factor=1.,
+        base=10000,
+        base_rescale_factor=1.
+    ):
+        super().__init__()
+        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
+        # has some connection to NTK literature
+        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+        base *= base_rescale_factor ** (dim / (dim - 2))
+
+        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer('inv_freq', inv_freq, persistent=False)
+
+        assert interpolation_factor >= 1.
+        self.interpolation_factor = interpolation_factor
+
+        if not use_xpos:
+            self.register_buffer('scale', None)
+        else:
+            scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
+
+            self.scale_base = scale_base
+            self.register_buffer('scale', scale)
+
+    def forward_from_seq_len(self, seq_len):
+        device = self.inv_freq.device
+
+        t = torch.arange(seq_len, device=device)
+        return self.forward(t)
+
+    @autocast(enabled=False)
+    def forward(self, t):
+        device = self.inv_freq.device
+
+        t = t.to(torch.float32)
+        seq_len = t.shape[0]
+
+        t = t / self.interpolation_factor
+
+        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
+        freqs = torch.cat((freqs, freqs), dim=-1)
+
+        if self.scale is None:
+            return freqs, 1.
+
+        power = (torch.arange(seq_len, device=device) - (seq_len // 2)) / self.scale_base
+        scale = self.scale ** rearrange(power, 'n -> n 1')
+        scale = torch.cat((scale, scale), dim=-1)
+
+        return freqs, scale
+
+class RotaryEmbedding2D(RotaryEmbedding):
+    def __init__(self, dim, w, **kwargs):
+        super().__init__(dim // 2, **kwargs)
+        self.w = w
+
+    def forward_from_seq_len(self, seq_len):
+        device = self.inv_freq.device
+        assert seq_len % self.w == 0, f"{seq_len} % {self.w} != 0"
+        h_len = seq_len // self.w
+
+        t_h = torch.arange(h_len, device=device)
+        t_w = torch.arange(self.w, device=device)
+
+        return self.forward(t_h, t_w)
+
+    @autocast(enabled=False)
+    def forward(self, t_h: torch.Tensor, t_w: torch.Tensor):
+        repeat_t_h = t_h.repeat_interleave(t_w.shape[0], dim=0)
+        repeat_t_w = t_w.repeat(t_h.shape[0])
+        freq_h, scale_h = super().forward(repeat_t_h)
+        freq_w, scale_w = super().forward(repeat_t_w)
+        freq = torch.stack([freq_h, freq_w], dim=-1)  # h*w, D//2, 2
+        freq = torch.cat(torch.unbind(freq, dim=-2), dim=-1)
+
+        if self.scale is None:
+            scale = 1.
+        else:
+            scale = torch.stack([scale_h, scale_w], dim=-1)
+            scale = torch.cat(torch.unbind(scale, dim=-2), dim=-1)
+
+        return freq, scale
+
+
+def rotate_half(x):
+    x = rearrange(x, '... (j d) -> ... j d', j=2)
+    x1, x2 = x.unbind(dim=-2)
+    return torch.cat((-x2, x1), dim=-1)
+
+@autocast(enabled=False)
+def apply_rotary_pos_emb(t, freqs, scale=1):
+    out_dtype = t.dtype
+
+    # cast to float32 if necessary for numerical stability
+    dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
+    rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
+    freqs, t = freqs.to(dtype), t.to(dtype)
+    freqs = freqs[-seq_len:, :]
+
+    if t.ndim == 4 and freqs.ndim == 3:
+        freqs = rearrange(freqs, 'b n d -> b 1 n d')
+
+    # partial rotary embeddings, Wang et al. GPT-J
+    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
+    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
+
+    t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
+    return torch.cat((t, t_unrotated), dim=-1)
+
+# norms
+class LayerNorm(nn.Module):
+    def __init__(self, dim, bias=False, fix_scale=False):
+        """
+        bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
+        """
+        super().__init__()
+
+        if fix_scale:
+            self.register_buffer("gamma", torch.ones(dim))
+        else:
+            self.gamma = nn.Parameter(torch.ones(dim))
+
+        if bias:
+            self.beta = nn.Parameter(torch.zeros(dim))
+        else:
+            self.register_buffer("beta", torch.zeros(dim))
+
+    def forward(self, x):
+        return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta)
+
+# feedforward
+
+class GLU(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_out,
+        activation: Callable,
+        use_conv=False,
+        conv_kernel_size=3,
+        bias=False,
+    ):
+        super().__init__()
+        self.act = activation
+        self.up_proj = nn.Linear(dim_in, dim_out, bias=bias) if not use_conv else nn.Conv1d(dim_in, dim_out, conv_kernel_size, padding=(conv_kernel_size // 2))
+        self.gate_proj = nn.Linear(dim_in, dim_out, bias=bias) if not use_conv else nn.Conv1d(dim_in, dim_out, conv_kernel_size, padding=(conv_kernel_size // 2))
+        self.use_conv = use_conv
+
+    def forward(self, x):
+        if self.use_conv:
+            x = rearrange(x, 'b n d -> b d n')
+            gate = self.gate_proj(x)
+            x = self.up_proj(x)
+            x = rearrange(x, 'b d n -> b n d')
+            gate = rearrange(gate, 'b d n -> b n d')
+        else:
+            gate = self.gate_proj(x)
+            x = self.up_proj(x)
+
+        return x * self.act(gate)
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim,
+        dim_out=None,
+        dim_ff=None,
+        no_bias=False,
+        glu=True,
+        use_conv=False,
+        conv_kernel_size=3,
+        zero_init_output=True,
+    ):
+        super().__init__()
+        inner_dim = dim_ff if dim_ff is not None else 4 * dim
+
+        # Default to SwiGLU
+
+        activation = nn.SiLU()
+
+        dim_out = dim if dim_out is None else dim_out
+
+        if glu:
+            linear_in = GLU(dim, inner_dim, activation, bias=not no_bias)
+        else:
+            linear_in = nn.Sequential(
+                Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
+                nn.Linear(dim, inner_dim, bias=not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding=(conv_kernel_size // 2), bias=not no_bias),
+                Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
+                activation
+            )
+
+        linear_out = nn.Linear(inner_dim, dim_out, bias=not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding=(conv_kernel_size // 2), bias=not no_bias)
+
+        # init last linear layer to 0
+        if zero_init_output:
+            nn.init.zeros_(linear_out.weight)
+            if not no_bias:
+                nn.init.zeros_(linear_out.bias)
+
+        self.ff = nn.Sequential(
+            linear_in,
+            Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
+            linear_out,
+            Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
+        )
+
+    def forward(self, x):
+        return self.ff(x)
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        dim_heads=64,
+        dim_context=None,
+        causal=False,
+        zero_init_output=True,
+        qk_norm: Literal['l2', 'ln', 'none'] = 'none',
+        natten_kernel_size=None
+    ):
+        super().__init__()
+        self.dim = dim
+        self.dim_heads = dim_heads
+        self.causal = causal
+
+        dim_kv = dim_context if dim_context is not None else dim
+
+        self.num_heads = dim // dim_heads
+        self.kv_heads = dim_kv // dim_heads
+
+        if dim_context is not None:
+            self.to_q = nn.Linear(dim, dim, bias=False)
+            self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
+        else:
+            self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
+
+        self.to_out = nn.Linear(dim, dim, bias=False)
+
+        if zero_init_output:
+            nn.init.zeros_(self.to_out.weight)
+
+        self.qk_norm = qk_norm
+
+        if self.qk_norm == "ln":
+            self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
+            self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
+
+        # Using 1d neighborhood attention
+        self.natten_kernel_size = natten_kernel_size
+        if natten_kernel_size is not None:
+            return
+
+        self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
+
+        self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None
+
+        self.sdp_kwargs = dict(
+            enable_flash=True,
+            enable_math=True,
+            enable_mem_efficient=True
+        )
+
+    def flash_attn(
+        self,
+        q,
+        k,
+        v,
+        mask=None,
+        causal=None
+    ):
+        batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device
+        kv_heads = k.shape[1]
+        # Recommended for multi-query single-key-value attention by Tri Dao
+        # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
+
+        if heads != kv_heads:
+            # Repeat interleave kv_heads to match q_heads
+            heads_per_kv_head = heads // kv_heads
+            k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v))
+
+        if k.ndim == 3:
+            k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
+
+        if v.ndim == 3:
+            v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
+
+        causal = self.causal if causal is None else causal
+
+        if q_len == 1 and causal:
+            causal = False
+
+        if mask is not None:
+            assert mask.ndim == 4
+            mask = mask.expand(batch, heads, q_len, k_len)
+
+        # handle kv cache - this should be bypassable in updated flash attention 2
+
+        if k_len > q_len and causal:
+            causal_mask = create_causal_mask(q_len, k_len, device=device)
+            if mask is None:
+                mask = ~causal_mask
+            else:
+                mask = mask & ~causal_mask
+            causal = False
+
+        # manually handle causal mask, if another mask was given
+
+        row_is_entirely_masked = None
+
+        if mask is not None and causal:
+            causal_mask = create_causal_mask(q_len, k_len, device=device)
+            mask = mask & ~causal_mask
+
+            # protect against an entire row being masked out
+
+            row_is_entirely_masked = ~mask.any(dim=-1)
+            mask[..., 0] = mask[..., 0] | row_is_entirely_masked
+
+            causal = False
+
+        with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
+            out = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=mask,
+                is_causal=causal
+            )
+
+        # for a row that is entirely masked out, zero out the output of that row's token
+
+        if row_is_entirely_masked is not None:
+            out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
+
+        return out
+
+    def forward(
+        self,
+        x,
+        context=None,
+        mask=None,
+        context_mask=None,
+        rotary_pos_emb=None,
+        causal=None
+    ):
+        h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
+
+        kv_input = context if has_context else x
+
+        if hasattr(self, 'to_q'):
+            # Use separate linear projections for q and k/v
+            q = self.to_q(x)
+            q = rearrange(q, 'b n (h d) -> b h n d', h=h)
+
+            k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+            k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=kv_h), (k, v))
+        else:
+            # Use fused linear projection
+            q, k, v = self.to_qkv(x).chunk(3, dim=-1)
+            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
+
+        # Normalize q and k for cosine sim attention
+        if self.qk_norm == "l2":
+            q = F.normalize(q, dim=-1)
+            k = F.normalize(k, dim=-1)
+        elif self.qk_norm == "ln":
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+
+        if rotary_pos_emb is not None and not has_context:
+            freqs, _ = rotary_pos_emb
+
+            q_dtype = q.dtype
+            k_dtype = k.dtype
+
+            q = q.to(torch.float32)
+            k = k.to(torch.float32)
+            freqs = freqs.to(torch.float32)
+
+            q = apply_rotary_pos_emb(q, freqs)
+            k = apply_rotary_pos_emb(k, freqs)
+
+            q = q.to(q_dtype)
+            k = k.to(k_dtype)
+
+        # TODO: both of these masks are in [B, K_len / Q_len] format;
+        # the context mask should perhaps become [B, Q_len, K_len],
+        # and the flash_attn path below assumes the valid part of the mask is left-aligned (all ones first)
+        input_mask = context_mask  # cross-attn
+        if input_mask is None and not has_context:  # self-attn
+            input_mask = mask
+
+        # determine masking
+        masks = []
+        final_attn_mask = None  # The mask that will be applied to the attention matrix, taking all masks into account
+
+        if input_mask is not None:
+            input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
+            masks.append(~input_mask)
+
+        # Other masks will be added here later
+
+        if len(masks) > 0:
+            final_attn_mask = ~or_reduce(masks)
+
+        n, device = q.shape[-2], q.device
+
+        causal = self.causal if causal is None else causal
+        if n == 1 and causal:
+            causal = False
+        if self.natten_kernel_size is not None:
+            if natten is None:
+                raise ImportError('natten not installed, please install natten to use neighborhood attention')
+
+            dtype_in = q.dtype
+            q, k, v = map(lambda t: t.to(torch.float32), (q, k, v))
+
+            attn = natten.functional.natten1dqk(q, k, kernel_size=self.natten_kernel_size, dilation=1)
+
+            if final_attn_mask is not None:
+                attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max)
+
+            attn = F.softmax(attn, dim=-1, dtype=torch.float32)
+
+            out = natten.functional.natten1dav(attn, v, kernel_size=self.natten_kernel_size, dilation=1).to(dtype_in)
+
+        # Prioritize Flash Attention 2
+        elif self.use_fa_flash:
+            fa_dtype_in = q.dtype
+            if q.dtype in [torch.float, torch.float32]:
+                target_dtype = self.to_out.weight.dtype if self.to_out.weight.dtype not in [torch.float, torch.float32] else torch.float16
+                warnings.warn(
+                    f"The input hidden states seem to have been silently cast to float32; this might be related to"
+                    f" upcasted embedding or layer norm layers. The input will be cast back to"
+                    f" {target_dtype}."
+                )
+                q, k, v = map(lambda t: t.to(target_dtype), (q, k, v))
+            q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v))
+            if final_attn_mask is not None:
+                # Check if the mask meets the requirement of FlashAttn
+                kv_seq_mask = final_attn_mask.squeeze(dim=[1, 2])
+                kv_reallens = kv_seq_mask.sum(dim=-1, dtype=torch.int32)
+                first_zero_indices = torch.argmax((kv_seq_mask == 0).int(), dim=1).masked_fill(kv_seq_mask[:, -1] != 0, kv_seq_mask.shape[1])
+                assert (kv_reallens == first_zero_indices).all(), f'{kv_reallens} , {first_zero_indices}'
+
+                batch_size, kv_seq_len, num_key_value_heads, head_dim = k.shape
+                unpad_k, indices_k, cu_seqlens_k, max_seqlen_in_batch_k = unpad_input(k, kv_seq_mask)
+                unpad_v = index_first_axis(
+                    v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+                )
+                q_seq_len = q.shape[1]
+                unpad_q, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(q, torch.ones((batch_size, q_seq_len), device=q.device, dtype=torch.bool))
+                out_unpad = flash_attn_varlen_func(
+                    unpad_q,
+                    unpad_k,
+                    unpad_v,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_in_batch_q,
+                    max_seqlen_k=max_seqlen_in_batch_k,
+                    causal=causal,
+                )
+                out = pad_input(out_unpad, indices_q, batch_size, q_seq_len)
+            else:
+                out = flash_attn_func(q, k, v, causal=causal)
+
+            out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
+        # Fall back to PyTorch implementation
+        elif self.use_pt_flash:
+            out = self.flash_attn(q, k, v, causal=causal, mask=final_attn_mask)
+
+        else:
+            # Fall back to custom implementation
+
+            if h != kv_h:
+                # Repeat interleave kv_heads to match q_heads
+                heads_per_kv_head = h // kv_h
+                k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v))
+
+            scale = 1. / (q.shape[-1] ** 0.5)
+
+            kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
+
+            dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
+
+            i, j, dtype = *dots.shape[-2:], dots.dtype
+
+            mask_value = -torch.finfo(dots.dtype).max
+
+            if final_attn_mask is not None:
+                dots = dots.masked_fill(~final_attn_mask, mask_value)
+
+            if causal:
+                causal_mask = create_causal_mask(i, j, device=device)
+                dots = dots.masked_fill(causal_mask, mask_value)
+
+            attn = F.softmax(dots, dim=-1, dtype=torch.float32)
+            attn = attn.type(dtype)
+
+            out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
+
+        # merge heads
+        out = rearrange(out, 'b h n d -> b n (h d)')
+
+        # Communicate between heads
+        out = self.to_out(out)
+
+        if mask is not None:
+            mask = rearrange(mask, 'b n -> b n 1')
+            out = out.masked_fill(~mask, 0.)
+
+        return out
+
+class ConformerModule(nn.Module):
+    def __init__(
+        self,
+        dim,
+        norm_kwargs={},
+    ):
+
+        super().__init__()
+
+        self.dim = dim
+
+        self.in_norm = LayerNorm(dim, **norm_kwargs)
+        self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
+        self.glu = GLU(dim, dim, nn.SiLU())
+        self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
+        self.mid_norm = LayerNorm(dim, **norm_kwargs)  # This is a batch norm in the original but I don't like batch norm
+        self.swish = nn.SiLU()
+        self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
+
+    def forward(self, x):
+        x = self.in_norm(x)
+        x = rearrange(x, 'b n d -> b d n')
+        x = self.pointwise_conv(x)
+        x = rearrange(x, 'b d n -> b n d')
+        x = self.glu(x)
+        x = rearrange(x, 'b n d -> b d n')
+        x = self.depthwise_conv(x)
+        x = rearrange(x, 'b d n -> b n d')
+        x = self.mid_norm(x)
+        x = self.swish(x)
+        x = rearrange(x, 'b n d -> b d n')
+        x = self.pointwise_conv_2(x)
+        x = rearrange(x, 'b d n -> b n d')
+
+        return x
+
+class TransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim,
+        dim_heads=64,
+        cross_attend=False,
+        dim_context=None,
+        global_cond_dim=None,
+        causal=False,
+        zero_init_branch_outputs=True,
+        conformer=False,
+        layer_ix=-1,
+        remove_norms=False,
+        attn_kwargs={},
+        ff_kwargs={},
+        norm_kwargs={}
+    ):
+
+        super().__init__()
+        self.dim = dim
+        self.dim_heads = dim_heads
+        self.cross_attend = cross_attend
+        self.dim_context = dim_context
+        self.causal = causal
+
+        self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
+
+        self.self_attn = Attention(
+            dim,
+            dim_heads=dim_heads,
+            causal=causal,
+            zero_init_output=zero_init_branch_outputs,
+            **attn_kwargs
+        )
+
+        if cross_attend:
+            self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
+            self.cross_attn = Attention(
+                dim,
+                dim_heads=dim_heads,
+                dim_context=dim_context,
+                causal=causal,
+                zero_init_output=zero_init_branch_outputs,
+                **attn_kwargs
+            )
+
+        self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
+        self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
+
+        self.layer_ix = layer_ix
+
+        self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
+
+        self.global_cond_dim = global_cond_dim
+
+        if global_cond_dim is not None:
+            self.to_scale_shift_gate = nn.Sequential(
+                nn.SiLU(),
+                nn.Linear(global_cond_dim, dim * 6, bias=False)
+            )
+
+            nn.init.zeros_(self.to_scale_shift_gate[1].weight)
+
+    def forward(
+        self,
+        x,
+        mask=None,
+        global_cond=None,
+        context=None,
+        context_mask=None,
+        rotary_pos_emb=None
+    ):
+        if self.global_cond_dim is not None:
+            assert global_cond is not None
+            scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim=-1)
+
+            # self-attention with adaLN
+            residual = x
+            x = self.pre_norm(x)
+            x = x * (1 + scale_self) + shift_self
+            x = self.self_attn(x, mask=mask, rotary_pos_emb=rotary_pos_emb)
+            x = x * torch.sigmoid(1 - gate_self)
+            x = x + residual
+
+            if context is not None:
+                x = x + self.cross_attn(self.cross_attend_norm(x), context=context, context_mask=context_mask)
+
+            if self.conformer is not None:
+                x = x + self.conformer(x)
+
+            # feedforward with adaLN
+            residual = x
+            x = self.ff_norm(x)
+            x = x * (1 + scale_ff) + shift_ff
+            x = self.ff(x)
+            x = x * torch.sigmoid(1 - gate_ff)
+            x = x + residual
+
+        else:
+            assert global_cond is None
+            x = x + self.self_attn(self.pre_norm(x), mask=mask, rotary_pos_emb=rotary_pos_emb)
+
+            if context is not None:
+                x = x + self.cross_attn(self.cross_attend_norm(x), context=context, context_mask=context_mask)
+
+            if self.conformer is not None:
+                x = x + self.conformer(x)
+
+            x = x + self.ff(self.ff_norm(x))
+
+        return x
+
+class ContinuousTransformer(nn.Module):
+    def __init__(
+        self,
+        dim,
+        depth,
+        *,
+        dim_in=None,
+        dim_out=None,
+        dim_heads=64,
+        cross_attend=False,
+        cross_atten_layer_idx=None,
+        cond_token_dim=None,
+        global_cond_dim=None,
+        causal=False,
+        rotary_pos_emb=True,
+        zero_init_branch_outputs=True,
+        conformer=False,
+        use_sinusoidal_emb=False,
+        use_abs_pos_emb=False,
+        abs_pos_emb_max_length=10000,
+        pos_emb_2d_size=1,
+        rotary_base_val=10000,
+        init_std=0.02,
+        **kwargs
+    ):
+
+        super().__init__()
+
+        self.dim = dim
+        self.depth = depth
+        self.causal = causal
+        self.layers = nn.ModuleList([])
+
+        self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
+        self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
+
+        if rotary_pos_emb:
+            if pos_emb_2d_size == 1:
+                self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), base=rotary_base_val)
+            else:
+                self.rotary_pos_emb = RotaryEmbedding2D(max(dim_heads // 2, 32), pos_emb_2d_size, base=rotary_base_val)
+        else:
+            self.rotary_pos_emb = None
+
+        self.use_sinusoidal_emb = use_sinusoidal_emb
+        if use_sinusoidal_emb:
+            if pos_emb_2d_size != 1:
+                raise NotImplementedError
+            self.pos_emb = ScaledSinusoidalEmbedding(dim)
+
+        self.use_abs_pos_emb = use_abs_pos_emb
+        if use_abs_pos_emb:
+            if pos_emb_2d_size != 1:
+                raise NotImplementedError
+            self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
+
+        if cross_atten_layer_idx is None:
+            cross_atten_layer_idx = list(range(depth))
+        for i in range(depth):
+            self.layers.append(
+                TransformerBlock(
+                    dim,
+                    dim_heads=dim_heads,
+                    cross_attend=cross_attend and (i in cross_atten_layer_idx),
+                    dim_context=cond_token_dim,
+                    global_cond_dim=global_cond_dim,
+                    causal=causal,
+                    zero_init_branch_outputs=zero_init_branch_outputs,
+                    conformer=conformer,
+                    layer_ix=i,
+                    **kwargs
+                )
+            )
+        self.gradient_checkpointing = False
+
+        self.apply(partial(self._init_weights, init_std=init_std))
+
+    def forward(
+        self,
+        x,
+        mask=None,
+        prepend_embeds=None,
+        prepend_mask=None,
+        global_cond=None,
+        return_info=False,
+        **kwargs
+    ):
+        batch, seq, device = *x.shape[:2], x.device
+
+        info = {
+            "hidden_states": [],
+        }
+
+        x = self.project_in(x)
+
+        if prepend_embeds is not None:
+            prepend_length, prepend_dim = prepend_embeds.shape[1:]
+
+            assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
+
+            x = torch.cat((prepend_embeds, x), dim=-2)
+
+            if prepend_mask is not None or mask is not None:
+                mask = mask if mask is not None else torch.ones((batch, seq), device=device, dtype=torch.bool)
+                prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device=device, dtype=torch.bool)
+
+                mask = torch.cat((prepend_mask, mask), dim=-1)
+
+        # Attention layers
+
+        if self.rotary_pos_emb is not None:
+            rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
+        else:
+            rotary_pos_emb = None
+
+        if self.use_sinusoidal_emb or self.use_abs_pos_emb:
+            x = x + self.pos_emb(x)
+
+        # Iterate over the transformer layers
+        context, context_mask = kwargs.pop('context', None), kwargs.pop("context_mask", None)
+
+        for layer_idx, layer in enumerate(self.layers):
+            if layer.cross_attend:
+                if self.gradient_checkpointing:
+                    x = checkpoint(layer, x, mask, global_cond, context, context_mask, rotary_pos_emb=rotary_pos_emb, **kwargs)
+                else:
+                    x = layer(x, mask, global_cond, context, context_mask, rotary_pos_emb=rotary_pos_emb, **kwargs)
+            else:
+                if self.gradient_checkpointing:
+                    x = checkpoint(layer, x, mask, global_cond, rotary_pos_emb=rotary_pos_emb, **kwargs)
+                else:
+                    x = layer(x, mask, global_cond, rotary_pos_emb=rotary_pos_emb, **kwargs)
+            if return_info:
+                info["hidden_states"].append(x)
+
+        x = self.project_out(x)
+
+        if return_info:
+            return x, info
+
+        return x
+
+    def gradient_checkpointing_enable(self):
+        self.gradient_checkpointing = True
+
+    def _init_weights(self, module, init_std=0.02):
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=init_std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=init_std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
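A quick sanity check for the rotary path above. The head dimensions follow ContinuousTransformer's default choice (rotating max(dim_heads // 2, 32) dims per head); everything else is toy data, not from the model:

# Toy check: rotary embedding keeps shapes intact and preserves per-position norms.
import torch

rope = RotaryEmbedding(max(64 // 2, 32))        # partial RoPE over the first 32 of 64 dims
q = torch.randn(2, 8, 16, 64)                   # (batch, heads, seq_len, head_dim)
freqs, scale = rope.forward_from_seq_len(16)
q_rot = apply_rotary_pos_emb(q, freqs, scale)
assert q_rot.shape == q.shape
# rotation in each 2-D plane is norm-preserving; the unrotated tail is untouched
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)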
SongBloom/models/vae_frontend/__init__.py
ADDED
@@ -0,0 +1,96 @@
+import torch
+from torch import nn
+import typing as tp
+import torchaudio
+import einops
+from abc import ABC, abstractmethod
+
+
+class AbstractVAE(ABC, nn.Module):
+
+    @property
+    @abstractmethod
+    def frame_rate(self) -> float:
+        ...
+
+    @property
+    @abstractmethod
+    def orig_sample_rate(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def channel_dim(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def split_bands(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def input_channel(self) -> int:
+        ...
+
+    def encode(self, wav) -> torch.Tensor:
+        ...
+
+    def decode(self, latents) -> torch.Tensor:
+        ...
+
+
+from .autoencoders import create_autoencoder_from_config, AudioAutoencoder
+class StableVAE(AbstractVAE):
+    def __init__(self, vae_ckpt, vae_cfg, sr=48000) -> None:
+        super().__init__()
+        import json
+        with open(vae_cfg) as f:
+            config = json.load(f)
+        self.vae: AudioAutoencoder = create_autoencoder_from_config(config)
+        self.vae.load_state_dict(torch.load(vae_ckpt)['state_dict'])
+        self.sample_rate = sr
+        self.rsp48k = torchaudio.transforms.Resample(sr, self.orig_sample_rate) if sr != self.orig_sample_rate else nn.Identity()
+
+    @torch.no_grad()
+    def encode(self, wav: torch.Tensor, sample=True) -> torch.Tensor:
+        wav = self.rsp48k(wav)
+        if wav.shape[-1] < 2048:
+            return torch.zeros((wav.shape[0], self.channel_dim, 0), device=wav.device, dtype=wav.dtype)
+        if wav.ndim == 2:
+            wav = wav.unsqueeze(1)
+        if wav.shape[1] == 1:
+            wav = wav.repeat(1, self.vae.in_channels, 1)
+        latent = self.vae.encode_audio(wav)  # B, 64, T
+        return latent
+
+    def decode(self, latents: torch.Tensor, **kwargs):
+        # B, 64, T
+        with torch.no_grad():
+            audio_recon = self.vae.decode_audio(latents, **kwargs)
+
+        return audio_recon
+
+    @property
+    def frame_rate(self) -> float:
+        return float(self.vae.sample_rate) / self.vae.downsampling_ratio
+
+    @property
+    def orig_sample_rate(self) -> int:
+        return self.vae.sample_rate
+
+    @property
+    def channel_dim(self) -> int:
+        return self.vae.latent_dim
+
+    @property
+    def split_bands(self) -> int:
+        return 1
+
+    @property
+    def input_channel(self) -> int:
+        return self.vae.in_channels
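A hedged round-trip sketch for StableVAE; the checkpoint and config paths are hypothetical placeholders, so this only runs once the real Stable Audio VAE weights and JSON config are in place:

# Hypothetical round trip; "vae.ckpt" / "vae_config.json" are placeholders.
import torch

vae = StableVAE(vae_ckpt="vae.ckpt", vae_cfg="vae_config.json", sr=48000).eval()
wav = torch.randn(1, 2 * 48000)        # (B, T): 2 s of mono audio at 48 kHz
latent = vae.encode(wav)               # (B, channel_dim, T / downsampling_ratio)
recon = vae.decode(latent)             # (B, in_channels, T')
print(latent.shape, recon.shape, vae.frame_rate)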
SongBloom/models/vae_frontend/autoencoders.py
ADDED
@@ -0,0 +1,657 @@
1 |
+
# https://github.com/Stability-AI/stable-audio-tools/tree/main/stable_audio_tools/models
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import math
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
from torchaudio import transforms as T
|
10 |
+
from dac.nn.layers import WNConv1d, WNConvTranspose1d
|
11 |
+
from typing import Literal, Dict, Any
|
12 |
+
import os,sys
|
13 |
+
|
14 |
+
sys.path.insert(0, os.path.dirname(__file__))
|
15 |
+
from bottleneck import create_bottleneck_from_config
|
16 |
+
|
17 |
+
class Bottleneck(nn.Module):
|
18 |
+
def __init__(self, is_discrete: bool = False):
|
19 |
+
super().__init__()
|
20 |
+
|
21 |
+
self.is_discrete = is_discrete
|
22 |
+
|
23 |
+
def encode(self, x, return_info=False, **kwargs):
|
24 |
+
raise NotImplementedError
|
25 |
+
|
26 |
+
def decode(self, x):
|
27 |
+
raise NotImplementedError
|
28 |
+
|
29 |
+
class DiscreteBottleneck(Bottleneck):
|
30 |
+
def __init__(self, num_quantizers, codebook_size, tokens_id):
|
31 |
+
super().__init__(is_discrete=True)
|
32 |
+
|
33 |
+
self.num_quantizers = num_quantizers
|
34 |
+
self.codebook_size = codebook_size
|
35 |
+
self.tokens_id = tokens_id
|
36 |
+
|
37 |
+
def decode_tokens(self, codes, **kwargs):
|
38 |
+
raise NotImplementedError
|
39 |
+
|
40 |
+
|
41 |
+
def checkpoint(function, *args, **kwargs):
|
42 |
+
kwargs.setdefault("use_reentrant", False)
|
43 |
+
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
|
48 |
+
# License available in LICENSES/LICENSE_NVIDIA.txt
|
49 |
+
def snake_beta(x, alpha, beta):
|
50 |
+
return x + (1.0 / (beta + 1e-9)) * pow(torch.sin(x * alpha), 2)
|
51 |
+
|
52 |
+
class SnakeBeta(nn.Module):
|
53 |
+
|
54 |
+
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
|
55 |
+
super(SnakeBeta, self).__init__()
|
56 |
+
self.in_features = in_features
|
57 |
+
|
58 |
+
# initialize alpha
|
59 |
+
self.alpha_logscale = alpha_logscale
|
60 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
61 |
+
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
|
62 |
+
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
|
63 |
+
else: # linear scale alphas initialized to ones
|
64 |
+
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
|
65 |
+
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
|
66 |
+
|
67 |
+
self.alpha.requires_grad = alpha_trainable
|
68 |
+
self.beta.requires_grad = alpha_trainable
|
69 |
+
|
70 |
+
self.no_div_by_zero = 1e-9
|
71 |
+
|
72 |
+
def forward(self, x):
|
73 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
74 |
+
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
75 |
+
if self.alpha_logscale:
|
76 |
+
alpha = torch.exp(alpha)
|
77 |
+
beta = torch.exp(beta)
|
78 |
+
x = snake_beta(x, alpha, beta)
|
79 |
+
|
80 |
+
return x
|
81 |
+
|
82 |
+
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
83 |
+
if activation == "elu":
|
84 |
+
act = nn.ELU()
|
85 |
+
elif activation == "snake":
|
86 |
+
act = SnakeBeta(channels)
|
87 |
+
elif activation == "none":
|
88 |
+
act = nn.Identity()
|
89 |
+
else:
|
90 |
+
raise ValueError(f"Unknown activation {activation}")
|
91 |
+
|
92 |
+
if antialias:
|
93 |
+
from alias_free_torch import Activation1d
|
94 |
+
act = Activation1d(act)
|
95 |
+
|
96 |
+
return act
|
97 |
+
|
class ResidualUnit(nn.Module):
    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()

        self.dilation = dilation

        padding = (dilation * (7 - 1)) // 2

        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=7, dilation=dilation, padding=padding),
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=out_channels, out_channels=out_channels,
                     kernel_size=1)
        )

    def forward(self, x):
        res = x

        # x = checkpoint(self.layers, x)
        x = self.layers(x)

        return x + res


class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()

        self.layers = nn.Sequential(
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=9, use_snake=use_snake),
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
        )

    def forward(self, x):
        return self.layers(x)


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()

        if use_nearest_upsample:
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=2 * stride,
                         stride=1,
                         bias=False,
                         padding='same')
            )
        else:
            upsample_layer = WNConvTranspose1d(in_channels=in_channels,
                                               out_channels=out_channels,
                                               kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))

        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            upsample_layer,
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=9, use_snake=use_snake),
        )

    def forward(self, x):
        return self.layers(x)

class OobleckEncoder(nn.Module):
    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=[1, 2, 4, 8],
                 strides=[2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False
                 ):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
        ]

        for i in range(self.depth - 1):
            layers += [EncoderBlock(in_channels=c_mults[i] * channels, out_channels=c_mults[i + 1] * channels, stride=strides[i], use_snake=use_snake)]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
            WNConv1d(in_channels=c_mults[-1] * channels, out_channels=latent_dim, kernel_size=3, padding=1)
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class OobleckDecoder(nn.Module):
    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=[1, 2, 4, 8],
                 strides=[2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1] * channels, kernel_size=7, padding=3),
        ]

        for i in range(self.depth - 1, 0, -1):
            layers += [DecoderBlock(
                in_channels=c_mults[i] * channels,
                out_channels=c_mults[i - 1] * channels,
                stride=strides[i - 1],
                use_snake=use_snake,
                antialias_activation=antialias_activation,
                use_nearest_upsample=use_nearest_upsample
            )]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
            WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
            nn.Tanh() if final_tanh else nn.Identity()
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

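# Shape sanity check (illustrative, not part of the original file): with the
# default strides [2, 4, 8, 8] the encoder downsamples by 2*4*8*8 = 512, so
# audio of length T (a multiple of 512) maps to latents of length T // 512,
# and the decoder maps them back.
def _demo_oobleck_roundtrip_shapes():
    enc = OobleckEncoder(in_channels=2, latent_dim=32)
    dec = OobleckDecoder(out_channels=2, latent_dim=32)
    x = torch.randn(1, 2, 4096)
    z = enc(x)   # -> [1, 32, 8]
    y = dec(z)   # -> [1, 2, 4096]
    assert z.shape == (1, 32, 8) and y.shape == (1, 2, 4096)
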
class AudioAutoencoder(nn.Module):
    def __init__(
        self,
        encoder,
        decoder,
        latent_dim,
        downsampling_ratio,
        sample_rate,
        io_channels=2,
        bottleneck=None,
        pretransform=None,
        in_channels=None,
        out_channels=None,
        soft_clip=False
    ):
        super().__init__()

        self.downsampling_ratio = downsampling_ratio
        self.sample_rate = sample_rate

        self.latent_dim = latent_dim
        self.io_channels = io_channels
        self.in_channels = io_channels
        self.out_channels = io_channels

        self.min_length = self.downsampling_ratio

        if in_channels is not None:
            self.in_channels = in_channels

        if out_channels is not None:
            self.out_channels = out_channels

        self.bottleneck = bottleneck

        self.encoder = encoder

        self.decoder = decoder

        self.pretransform = pretransform

        self.soft_clip = soft_clip

        self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete

    def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):

        info = {}

        if self.pretransform is not None and not skip_pretransform:
            if self.pretransform.enable_grad:
                if iterate_batch:
                    audios = []
                    for i in range(audio.shape[0]):
                        audios.append(self.pretransform.encode(audio[i:i + 1]))
                    audio = torch.cat(audios, dim=0)
                else:
                    audio = self.pretransform.encode(audio)
            else:
                with torch.no_grad():
                    if iterate_batch:
                        audios = []
                        for i in range(audio.shape[0]):
                            audios.append(self.pretransform.encode(audio[i:i + 1]))
                        audio = torch.cat(audios, dim=0)
                    else:
                        audio = self.pretransform.encode(audio)

        if self.encoder is not None:
            if iterate_batch:
                latents = []
                for i in range(audio.shape[0]):
                    latents.append(self.encoder(audio[i:i + 1]))
                latents = torch.cat(latents, dim=0)
            else:
                latents = self.encoder(audio)
        else:
            latents = audio

        if self.bottleneck is not None:
            # TODO: Add iterate_batch logic, needs to merge the info dicts
            latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)

            info.update(bottleneck_info)

        if return_info:
            return latents, info

        return latents

    def decode(self, latents, iterate_batch=False, **kwargs):

        if self.bottleneck is not None:
            if iterate_batch:
                decoded = []
                for i in range(latents.shape[0]):
                    decoded.append(self.bottleneck.decode(latents[i:i + 1]))
                latents = torch.cat(decoded, dim=0)
            else:
                latents = self.bottleneck.decode(latents)

        if iterate_batch:
            decoded = []
            for i in range(latents.shape[0]):
                decoded.append(self.decoder(latents[i:i + 1]))
            decoded = torch.cat(decoded, dim=0)
        else:
            decoded = self.decoder(latents, **kwargs)

        if self.pretransform is not None:
            if self.pretransform.enable_grad:
                if iterate_batch:
                    decodeds = []
                    for i in range(decoded.shape[0]):
                        decodeds.append(self.pretransform.decode(decoded[i:i + 1]))
                    decoded = torch.cat(decodeds, dim=0)
                else:
                    decoded = self.pretransform.decode(decoded)
            else:
                with torch.no_grad():
                    if iterate_batch:
                        decodeds = []
                        for i in range(decoded.shape[0]):  # iterate over decoded, not latents
                            decodeds.append(self.pretransform.decode(decoded[i:i + 1]))
                        decoded = torch.cat(decodeds, dim=0)
                    else:
                        decoded = self.pretransform.decode(decoded)

        if self.soft_clip:
            decoded = torch.tanh(decoded)

        return decoded

    def decode_tokens(self, tokens, **kwargs):
        '''
        Decode discrete tokens to audio.
        Only works with discrete autoencoders.
        '''

        assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"

        latents = self.bottleneck.decode_tokens(tokens, **kwargs)

        return self.decode(latents, **kwargs)

    def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
        '''
        Encode audio into latents. Audio should already be preprocessed by preprocess_audio_for_encoder.
        If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
        Overlap and chunk_size params are both measured in number of latents (not audio samples),
        so you can likely use the same values with decode_audio.
        An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
        Every autoencoder has a different receptive field size, and thus a different ideal overlap.
        You can determine it empirically by diffing unchunked vs chunked output and looking at the maximum diff.
        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
        A smaller chunk_size uses less memory, but more compute.
        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
        For example, on an A6000, chunk_size 128 is overall faster than 256 and 512, even though it has more chunks.
        '''
        if not chunked:
            # Default behavior: encode the entire audio in parallel
            return self.encode(audio, **kwargs)
        else:
            # CHUNKED ENCODING
            # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
            samples_per_latent = self.downsampling_ratio
            total_size = audio.shape[2]  # in samples
            batch_size = audio.shape[0]
            chunk_size *= samples_per_latent  # convert metric in latents to samples
            overlap *= samples_per_latent  # convert metric in latents to samples
            hop_size = chunk_size - overlap
            # Assumes total_size >= chunk_size, i.e. the loop runs at least once
            chunks = []
            for i in range(0, total_size - chunk_size + 1, hop_size):
                chunk = audio[:, :, i:i + chunk_size]
                chunks.append(chunk)
            if i + chunk_size != total_size:
                # Final chunk
                chunk = audio[:, :, -chunk_size:]
                chunks.append(chunk)
            chunks = torch.stack(chunks)
            num_chunks = chunks.shape[0]
            # Note: y_size might be a different value from the latent length used in diffusion training
            # because we can encode audio of varying lengths.
            # However, the audio should've been padded to a multiple of samples_per_latent by now.
            y_size = total_size // samples_per_latent
            # Create an empty latent; we will populate it with chunks as we encode them
            y_final = torch.zeros((batch_size, self.latent_dim, y_size)).to(audio.device)
            for i in range(num_chunks):
                x_chunk = chunks[i, :]
                # encode the chunk
                y_chunk = self.encode(x_chunk)
                # figure out where to put the latents along the time axis
                if i == num_chunks - 1:
                    # final chunk always goes at the end
                    t_end = y_size
                    t_start = t_end - y_chunk.shape[2]
                else:
                    t_start = i * hop_size // samples_per_latent
                    t_end = t_start + chunk_size // samples_per_latent
                # remove the edges of the overlaps
                ol = overlap // samples_per_latent // 2
                chunk_start = 0
                chunk_end = y_chunk.shape[2]
                if i > 0:
                    # no overlap for the start of the first chunk
                    t_start += ol
                    chunk_start += ol
                if i < num_chunks - 1:
                    # no overlap for the end of the last chunk
                    t_end -= ol
                    chunk_end -= ol
                # paste the chunk's latents into the y_final output
                y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
            return y_final

    def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
        '''
        Decode latents to audio.
        If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
        An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
        Every autoencoder has a different receptive field size, and thus a different ideal overlap.
        You can determine it empirically by diffing unchunked vs chunked audio and looking at the maximum diff.
        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
        A smaller chunk_size uses less memory, but more compute.
        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
        For example, on an A6000, chunk_size 128 is overall faster than 256 and 512, even though it has more chunks.
        '''
        if not chunked:
            # Default behavior: decode the entire latent in parallel
            return self.decode(latents, **kwargs)
        else:
            # chunked decoding
            hop_size = chunk_size - overlap
            total_size = latents.shape[2]
            batch_size = latents.shape[0]
            # Assumes total_size >= chunk_size, i.e. the loop runs at least once
            chunks = []
            for i in range(0, total_size - chunk_size + 1, hop_size):
                chunk = latents[:, :, i:i + chunk_size]
                chunks.append(chunk)
            if i + chunk_size != total_size:
                # Final chunk
                chunk = latents[:, :, -chunk_size:]
                chunks.append(chunk)
            chunks = torch.stack(chunks)
            num_chunks = chunks.shape[0]
            # samples_per_latent is just the downsampling ratio
            samples_per_latent = self.downsampling_ratio
            # Create an empty waveform; we will populate it with chunks as we decode them
            y_size = total_size * samples_per_latent
            y_final = torch.zeros((batch_size, self.out_channels, y_size)).to(latents.device)
            for i in range(num_chunks):
                x_chunk = chunks[i, :]
                # decode the chunk
                y_chunk = self.decode(x_chunk)
                # figure out where to put the audio along the time axis
                if i == num_chunks - 1:
                    # final chunk always goes at the end
                    t_end = y_size
                    t_start = t_end - y_chunk.shape[2]
                else:
                    t_start = i * hop_size * samples_per_latent
                    t_end = t_start + chunk_size * samples_per_latent
                # remove the edges of the overlaps
                ol = (overlap // 2) * samples_per_latent
                chunk_start = 0
                chunk_end = y_chunk.shape[2]
                if i > 0:
                    # no overlap for the start of the first chunk
                    t_start += ol
                    chunk_start += ol
                if i < num_chunks - 1:
                    # no overlap for the end of the last chunk
                    t_end -= ol
                    chunk_end -= ol
                # paste the chunked audio into the y_final output
                y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
            return y_final

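# Usage sketch (illustrative, not part of the original file): chunked
# encode/decode bounds peak memory for long inputs; overlap and chunk_size
# are measured in latents, so the same values can be reused in both
# directions. `ae` is any constructed AudioAutoencoder.
def _demo_chunked_roundtrip(ae: AudioAutoencoder, audio):
    # audio: [B, in_channels, T] with T a multiple of ae.downsampling_ratio
    latents = ae.encode_audio(audio, chunked=True, overlap=32, chunk_size=128)
    return ae.decode_audio(latents, chunked=True, overlap=32, chunk_size=128)
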
# AE factories

def create_encoder_from_config(encoder_config: Dict[str, Any]):
    encoder_type = encoder_config.get("type", None)
    assert encoder_type is not None, "Encoder type must be specified"

    if encoder_type == "oobleck":
        encoder = OobleckEncoder(
            **encoder_config["config"]
        )
    elif encoder_type == "seanet":
        from encodec.modules import SEANetEncoder
        seanet_encoder_config = encoder_config["config"]

        # SEANet encoder expects strides in reverse order
        seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
        encoder = SEANetEncoder(
            **seanet_encoder_config
        )
    else:
        raise ValueError(f"Unknown encoder type {encoder_type}")

    requires_grad = encoder_config.get("requires_grad", True)
    if not requires_grad:
        for param in encoder.parameters():
            param.requires_grad = False

    return encoder


def create_decoder_from_config(decoder_config: Dict[str, Any]):
    decoder_type = decoder_config.get("type", None)
    assert decoder_type is not None, "Decoder type must be specified"

    if decoder_type == "oobleck":
        decoder = OobleckDecoder(
            **decoder_config["config"]
        )
    elif decoder_type == "seanet":
        from encodec.modules import SEANetDecoder

        decoder = SEANetDecoder(
            **decoder_config["config"]
        )
    else:
        raise ValueError(f"Unknown decoder type {decoder_type}")

    requires_grad = decoder_config.get("requires_grad", True)
    if not requires_grad:
        for param in decoder.parameters():
            param.requires_grad = False

    return decoder


def create_autoencoder_from_config(config: Dict[str, Any]):

    ae_config = config["model"]

    encoder = create_encoder_from_config(ae_config["encoder"])
    decoder = create_decoder_from_config(ae_config["decoder"])

    bottleneck = ae_config.get("bottleneck", None)

    latent_dim = ae_config.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = ae_config.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = ae_config.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    in_channels = ae_config.get("in_channels", None)
    out_channels = ae_config.get("out_channels", None)

    pretransform = ae_config.get("pretransform", None)

    if pretransform is not None:
        from stable_audio_tools.models.factory import create_pretransform_from_config
        pretransform = create_pretransform_from_config(pretransform, sample_rate)

    if bottleneck is not None:
        bottleneck = create_bottleneck_from_config(bottleneck)

    soft_clip = ae_config["decoder"].get("soft_clip", False)

    return AudioAutoencoder(
        encoder,
        decoder,
        io_channels=io_channels,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        sample_rate=sample_rate,
        bottleneck=bottleneck,
        pretransform=pretransform,
        in_channels=in_channels,
        out_channels=out_channels,
        soft_clip=soft_clip
    )

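# Illustrative config sketch (the exact schema is an assumption inferred from
# the factory above, not a file shipped with this repo): model-level settings
# plus nested encoder/decoder/bottleneck sub-configs.
_EXAMPLE_AE_CONFIG = {
    "sample_rate": 44100,
    "model": {
        "io_channels": 2,
        "latent_dim": 32,
        "downsampling_ratio": 512,  # product of the default oobleck strides [2, 4, 8, 8]
        "encoder": {"type": "oobleck", "config": {"in_channels": 2, "latent_dim": 64}},
        "decoder": {"type": "oobleck", "config": {"out_channels": 2, "latent_dim": 32}},
        "bottleneck": {"type": "vae"},  # encoder latent_dim doubled for the VAE mean/scale split
    },
}
# ae = create_autoencoder_from_config(_EXAMPLE_AE_CONFIG)
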
if __name__ == "__main__":
    import json
    import torchaudio

    config_path = 'modelzoo/stable_audio_vae/stable_audio_2_0_vae.json'
    with open(config_path) as f:
        config = json.load(f)
    with torch.no_grad():
        vae_model = create_autoencoder_from_config(config).cuda()
        model_ckpt_path = 'modelzoo/stable_audio_vae/autoencoder.ckpt'
        vae_model.load_state_dict(torch.load(model_ckpt_path)['state_dict'])

        # Load a test clip, resample to 48 kHz, and keep the first 2048 samples
        input_audios, sr = torchaudio.load("music_example/加勒比海盗 主题.wav")
        input_audios = torchaudio.functional.resample(input_audios, sr, 48000)[..., :2048]
        # [C, T] -> [C, 2, T]: treat each source channel as a batch item duplicated to stereo
        input_audios = input_audios.unsqueeze(1).repeat(1, 2, 1).cuda()
        latents = vae_model.encode_audio(input_audios)
        recover_audio = vae_model.decode_audio(latents)
        print(recover_audio)

        breakpoint()  # drop into the debugger to inspect the round-tripped audio