mrfakename commited on
Commit
a0e2cb7
·
0 Parent(s):
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +38 -0
  2. .gitignore +215 -0
  3. README.md +83 -0
  4. SongBloom/g2p/__init__.py +0 -0
  5. SongBloom/g2p/cn_zh_g2p/__init__.py +106 -0
  6. SongBloom/g2p/cn_zh_g2p/chinese.py +173 -0
  7. SongBloom/g2p/cn_zh_g2p/cmudict-fast.rep +0 -0
  8. SongBloom/g2p/cn_zh_g2p/cmudict.rep +0 -0
  9. SongBloom/g2p/cn_zh_g2p/engdict-hot.rep +2 -0
  10. SongBloom/g2p/cn_zh_g2p/engdict_cache.pickle +3 -0
  11. SongBloom/g2p/cn_zh_g2p/english.py +369 -0
  12. SongBloom/g2p/cn_zh_g2p/nltk_data/corpora/cmudict.zip +3 -0
  13. SongBloom/g2p/cn_zh_g2p/nltk_data/corpora/cmudict/README +76 -0
  14. SongBloom/g2p/cn_zh_g2p/nltk_data/corpora/cmudict/cmudict +0 -0
  15. SongBloom/g2p/cn_zh_g2p/nltk_data/taggers/averaged_perceptron_tagger.zip +3 -0
  16. SongBloom/g2p/cn_zh_g2p/nltk_data/taggers/averaged_perceptron_tagger/averaged_perceptron_tagger.pickle +3 -0
  17. SongBloom/g2p/cn_zh_g2p/opencpop-strict.txt +429 -0
  18. SongBloom/g2p/cn_zh_g2p/symbols.py +401 -0
  19. SongBloom/g2p/cn_zh_g2p/tone_sandhi.py +806 -0
  20. SongBloom/g2p/cn_zh_g2p/zh_normalization/README.md +16 -0
  21. SongBloom/g2p/cn_zh_g2p/zh_normalization/__init__.py +14 -0
  22. SongBloom/g2p/cn_zh_g2p/zh_normalization/char_convert.py +46 -0
  23. SongBloom/g2p/cn_zh_g2p/zh_normalization/chronology.py +134 -0
  24. SongBloom/g2p/cn_zh_g2p/zh_normalization/constants.py +62 -0
  25. SongBloom/g2p/cn_zh_g2p/zh_normalization/num.py +282 -0
  26. SongBloom/g2p/cn_zh_g2p/zh_normalization/phonecode.py +63 -0
  27. SongBloom/g2p/cn_zh_g2p/zh_normalization/quantifier.py +63 -0
  28. SongBloom/g2p/cn_zh_g2p/zh_normalization/text_normlization.py +165 -0
  29. SongBloom/g2p/lyric_common.py +81 -0
  30. SongBloom/g2p/pinyin/__init__.py +430 -0
  31. SongBloom/g2p/pinyin/pinyin.py +137 -0
  32. SongBloom/g2p/pinyin/symbols.py +71 -0
  33. SongBloom/models/base/sample.py +57 -0
  34. SongBloom/models/base/utils.py +57 -0
  35. SongBloom/models/musicgen/__init__.py +0 -0
  36. SongBloom/models/musicgen/conditioners/__init__.py +37 -0
  37. SongBloom/models/musicgen/conditioners/base.py +872 -0
  38. SongBloom/models/musicgen/conditioners/text.py +254 -0
  39. SongBloom/models/musicgen/conditioners/wav.py +74 -0
  40. SongBloom/models/musicgen/get_backend.py +76 -0
  41. SongBloom/models/musicgen/modules/streaming.py +125 -0
  42. SongBloom/models/musicldm/__init__.py +0 -0
  43. SongBloom/models/musicldm/inference/__init__.py +0 -0
  44. SongBloom/models/musicldm/inference/sampling.py +271 -0
  45. SongBloom/models/musicldm/musicldm_dit.py +24 -0
  46. SongBloom/models/songbloom/songbloom_mvsa.py +572 -0
  47. SongBloom/models/songbloom/songbloom_pl.py +224 -0
  48. SongBloom/models/transformer.py +937 -0
  49. SongBloom/models/vae_frontend/__init__.py +96 -0
  50. SongBloom/models/vae_frontend/autoencoders.py +657 -0
.gitattributes ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.flac filter=lfs diff=lfs merge=lfs -text
37
+ *.wav filter=lfs diff=lfs merge=lfs -text
38
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+ #poetry.toml
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
+ #pdm.lock
116
+ #pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # pixi
121
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
+ #pixi.lock
123
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
+ # in the .venv directory. It is recommended not to include this directory in version control.
125
+ .pixi
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # SageMath parsed files
135
+ *.sage.py
136
+
137
+ # Environments
138
+ .env
139
+ .envrc
140
+ .venv
141
+ env/
142
+ venv/
143
+ ENV/
144
+ env.bak/
145
+ venv.bak/
146
+
147
+ # Spyder project settings
148
+ .spyderproject
149
+ .spyproject
150
+
151
+ # Rope project settings
152
+ .ropeproject
153
+
154
+ # mkdocs documentation
155
+ /site
156
+
157
+ # mypy
158
+ .mypy_cache/
159
+ .dmypy.json
160
+ dmypy.json
161
+
162
+ # Pyre type checker
163
+ .pyre/
164
+
165
+ # pytype static type analyzer
166
+ .pytype/
167
+
168
+ # Cython debug symbols
169
+ cython_debug/
170
+
171
+ # PyCharm
172
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
175
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176
+ #.idea/
177
+
178
+ # Abstra
179
+ # Abstra is an AI-powered process automation framework.
180
+ # Ignore directories containing user credentials, local state, and settings.
181
+ # Learn more at https://abstra.io/docs
182
+ .abstra/
183
+
184
+ # Visual Studio Code
185
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
188
+ # you could uncomment the following to ignore the entire vscode folder
189
+ # .vscode/
190
+
191
+ # Ruff stuff:
192
+ .ruff_cache/
193
+
194
+ # PyPI configuration file
195
+ .pypirc
196
+
197
+ # Cursor
198
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
199
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
200
+ # refer to https://docs.cursor.com/context/ignore-files
201
+ .cursorignore
202
+ .cursorindexingignore
203
+
204
+ # Marimo
205
+ marimo/_static/
206
+ marimo/_lsp/
207
+ __marimo__/
208
+
209
+ # Streamlit
210
+ .streamlit/secrets.toml
211
+
212
+ __pycache__
213
+ output
214
+ cache
215
+ third_party
README.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # [SongBloom]: *Coherent Song Generation via Interleaved Autoregressive Sketching and Diffusion Refinement*
2
+
3
+ We propose **SongBloom**, a novel framework for full-length song generation that leverages an interleaved paradigm of autoregressive sketching and diffusion-based refinement. SongBloom employs an autoregressive diffusion model that combines the high fidelity of diffusion models with the scalability of language models.
4
+ Specifically, it gradually extends a musical sketch from short to long and refines the details from coarse to fine-grained. The interleaved generation paradigm effectively integrates prior semantic and acoustic context to guide the generation process.
5
+ Experimental results demonstrate that SongBloom outperforms existing methods across both subjective and objective metrics and achieves performance comparable to the state-of-the-art commercial music generation platforms.
6
+
7
+ ![img](docs/architecture.png)
8
+
9
+ Demo page: [https://cypress-yang.github.io/SongBloom_demo](https://cypress-yang.github.io/SongBloom_demo)
10
+
11
+ ArXiv: [https://arxiv.org/abs/2506.07634](https://arxiv.org/abs/2506.07634)
12
+
13
+ ## Prepare Environments
14
+
15
+ ```bash
16
+ conda create -n SongBloom python==3.8.12
17
+ conda activate SongBloom
18
+
19
+ # yum install libsndfile
20
+ # pip install torch==2.2.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118 # For different CUDA version
21
+ pip install -r requirements.txt
22
+ ```
23
+
24
+ ## Data Preparation
25
+
26
+ A .jsonl file, where each line is a json object:
27
+
28
+ ```json
29
+ {
30
+ "idx": "The index of each sample",
31
+ "lyrics": "The lyrics to be generated",
32
+ "prompt_wav": "The path of the style prompt audio",
33
+ }
34
+ ```
35
+
36
+ One example can be refered to as: [example/test.jsonl](example/test.jsonl)
37
+
38
+ The prompt wav should be a 10-second, 48kHz audio clip.
39
+
40
+ The details about lyric format can be found in [docs/lyric_format.md](docs/lyric_format.md).
41
+
42
+ ## Inference
43
+
44
+ ```bash
45
+ source set_env.sh
46
+
47
+ python3 infer.py --input-jsonl example/test.jsonl
48
+
49
+ # For GPUs with low VRAM like RTX4090, you should set the dtype as bfloat16
50
+ python3 infer.py --input-jsonl example/test.jsonl --dtype bfloat16
51
+
52
+ # SongBloom also supports flash-attn (optional). To enable it, please install flash-attn (v2.6.3 is used during training) manually and set os.environ['DISABLE_FLASH_ATTN'] = "0" in infer.py:8
53
+ ```
54
+
55
+ ## Models
56
+
57
+ | Name | Size | Max Length | Prompt type | 🤗 |
58
+ | -------------------- | ---- | ---------- | ----------- | -------------------------------------------- |
59
+ | songbloom_full_150s | 2B | 2m30s | 10s wav | [link](https://huggingface.co/CypressYang/SongBloom) |
60
+ | songbloom_mulan_150s | 2B | 2m30s | 10s wav / text description | coming soon |
61
+ | ... | | | | |
62
+
63
+
64
+
65
+ ## TODO List
66
+
67
+ - [ ] Support Text Description
68
+ - [ ] Full version
69
+
70
+ ## Citation
71
+
72
+ ```
73
+ @article{yang2025songbloom,
74
+ title={SongBloom: Coherent Song Generation via Interleaved Autoregressive Sketching and Diffusion Refinement},
75
+ author={Yang, Chenyu and Wang, Shuai and Chen, Hangting and Tan, Wei and Yu, Jianwei and Li, Haizhou},
76
+ journal={arXiv preprint arXiv:2506.07634},
77
+ year={2025}
78
+ }
79
+ ```
80
+
81
+ ## License
82
+
83
+ SongBloom (codes and weights) is released under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
SongBloom/g2p/__init__.py ADDED
File without changes
SongBloom/g2p/cn_zh_g2p/__init__.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import chinese, english # , japanese 暂时干掉看看
2
+ from .symbols import *
3
+ import yaml
4
+ language_module_map = {"zh": chinese, "en": english} #, "ja": japanese
5
+
6
+ def is_chinese(uchar):
7
+ if uchar >= u'\u4e00' and uchar <= u'\u9fa5':
8
+ return True
9
+ else:
10
+ return False
11
+
12
+ import re
13
+
14
+ # def split_text(text):
15
+ # chinese_pattern = r'[\u4e00-\u9fa5][\u4e00-\u9fa5\ \,\.\!\?\,\。]+'
16
+ # english_pattern = r'[a-zA-Z][a-zA-Z\'\ \,\.\!\?]+'
17
+
18
+ # chinese_text = re.findall(chinese_pattern, text)
19
+ # print(chinese_text)
20
+ # english_text = re.findall(english_pattern, text)
21
+
22
+ # return chinese_text, english_text
23
+
24
+ def split_text(text):
25
+ pattern = re.compile("|".join(re.escape(p) for p in chinese.rep_map.keys()))
26
+ text = pattern.sub(lambda x: chinese.rep_map[x.group()], text)
27
+
28
+ result = []
29
+ lang = []
30
+ buffer = ""
31
+ chinese_pattern = r'[\u4e00-\u9fa5]'
32
+ special_pattern = r'[\,\.\!\?\…\-]'
33
+ # TODO check 一下
34
+ for char in text:
35
+ if re.match(special_pattern, char):
36
+ if buffer:
37
+ if not re.match(chinese_pattern, buffer[0]):
38
+ result.append(buffer)
39
+ lang.append('en')
40
+ else:
41
+ result.append(buffer)
42
+ lang.append("zh")
43
+ result.append(char)
44
+ lang.append('sp')
45
+ buffer = ""
46
+
47
+
48
+ elif re.match(chinese_pattern, char):
49
+ if buffer and not re.match(chinese_pattern, buffer[-1]):
50
+ result.append(buffer)
51
+ buffer = ""
52
+ lang.append('en')
53
+ buffer += char
54
+ else:
55
+ if buffer and re.match(chinese_pattern, buffer[-1]):
56
+ result.append(buffer)
57
+ buffer = ""
58
+ lang.append("zh")
59
+ buffer += char
60
+
61
+ if buffer:
62
+ result.append(buffer)
63
+ lang.append("zh" if re.match(chinese_pattern, buffer[-1]) else 'en')
64
+
65
+ return result, lang
66
+
67
+ def mixed_language_to_phoneme(text):
68
+ segments, lang = split_text(text)
69
+ # print(segments, lang)
70
+ result = [language_to_phoneme(s, l) for s, l in zip(segments, lang)]
71
+ phones, word2ph = [], []
72
+ for p, w, n in result:
73
+ phones += p
74
+ if w is None:
75
+ w = []
76
+ word2ph += w
77
+ return phones, word2ph
78
+
79
+
80
+ def language_to_phoneme(text, language):
81
+ if language == 'sp':
82
+ return [text], None, text
83
+ language_module = language_module_map[language]
84
+ norm_text = language_module.text_normalize(text)
85
+ if language == "zh":
86
+ phones, word2ph = language_module.g2p(norm_text)
87
+ assert len(phones) == sum(word2ph)
88
+ assert len(norm_text) == len(word2ph)
89
+ else:
90
+ try:
91
+ phones = language_module.g2p(norm_text)
92
+ except:
93
+ phones = [norm_text]
94
+ word2ph = None
95
+
96
+ # for ph in phones:
97
+ # assert ph in symbols, ph
98
+ return phones, word2ph, norm_text
99
+
100
+ def gen_vocabs():
101
+ yaml.dump(symbols, open('./vocab.yaml', 'w'))
102
+
103
+ class G2P_Mix():
104
+ def __call__(self, text):
105
+ phones, word2ph = mixed_language_to_phoneme(text)
106
+ return ' '.join(phones)
SongBloom/g2p/cn_zh_g2p/chinese.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pdb
3
+ import re
4
+
5
+ import cn2an
6
+ from pypinyin import lazy_pinyin, Style
7
+
8
+ from .symbols import punctuation
9
+ from .tone_sandhi import ToneSandhi
10
+ from .zh_normalization.text_normlization import TextNormalizer
11
+
12
+ normalizer = lambda x: cn2an.transform(x, "an2cn")
13
+
14
+ current_file_path = os.path.dirname(__file__)
15
+ pinyin_to_symbol_map = {
16
+ line.split("\t")[0]: line.strip().split("\t")[1]
17
+ for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
18
+ }
19
+
20
+ import jieba_fast.posseg as psg
21
+
22
+
23
+ rep_map = {
24
+ ":": ",",
25
+ ";": ",",
26
+ ",": ",",
27
+ "。": ".",
28
+ "!": "!",
29
+ "?": "?",
30
+ "\n": ".",
31
+ "·": ",",
32
+ "、": ",",
33
+ "...": "…",
34
+ "$": ".",
35
+ "/": ",",
36
+ "—": "-",
37
+ "~": "…",
38
+ "~":"…",
39
+ }
40
+
41
+ tone_modifier = ToneSandhi()
42
+
43
+
44
+ def replace_punctuation(text):
45
+ text = text.replace("嗯", "恩").replace("呣", "母")
46
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
47
+
48
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
49
+ replaced_text = re.sub(
50
+ r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
51
+ )
52
+
53
+ return replaced_text
54
+
55
+
56
+ def g2p(text):
57
+ pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
58
+ sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
59
+ phones, word2ph = _g2p(sentences)
60
+ return phones, word2ph
61
+
62
+
63
+ def _get_initials_finals(word):
64
+ initials = []
65
+ finals = []
66
+ orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
67
+ orig_finals = lazy_pinyin(
68
+ word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
69
+ )
70
+ for c, v in zip(orig_initials, orig_finals):
71
+ initials.append(c)
72
+ finals.append(v)
73
+ return initials, finals
74
+
75
+
76
+ def _g2p(segments):
77
+ phones_list = []
78
+ word2ph = []
79
+ for seg in segments:
80
+ pinyins = []
81
+ # Replace all English words in the sentence
82
+ seg = re.sub("[a-zA-Z]+", "", seg)
83
+ seg_cut = psg.lcut(seg)
84
+ initials = []
85
+ finals = []
86
+ seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
87
+ for word, pos in seg_cut:
88
+ if pos == "eng":
89
+ continue
90
+ sub_initials, sub_finals = _get_initials_finals(word)
91
+ sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
92
+ initials.append(sub_initials)
93
+ finals.append(sub_finals)
94
+
95
+ # assert len(sub_initials) == len(sub_finals) == len(word)
96
+ initials = sum(initials, [])
97
+ finals = sum(finals, [])
98
+ #
99
+ for c, v in zip(initials, finals):
100
+ raw_pinyin = c + v
101
+ # NOTE: post process for pypinyin outputs
102
+ # we discriminate i, ii and iii
103
+ if c == v:
104
+ assert c in punctuation
105
+ phone = [c]
106
+ word2ph.append(1)
107
+ else:
108
+ v_without_tone = v[:-1]
109
+ tone = v[-1]
110
+
111
+ pinyin = c + v_without_tone
112
+ assert tone in "12345"
113
+
114
+ if c:
115
+ # 多音节
116
+ v_rep_map = {
117
+ "uei": "ui",
118
+ "iou": "iu",
119
+ "uen": "un",
120
+ }
121
+ if v_without_tone in v_rep_map.keys():
122
+ pinyin = c + v_rep_map[v_without_tone]
123
+ else:
124
+ # 单音节
125
+ pinyin_rep_map = {
126
+ "ing": "ying",
127
+ "i": "yi",
128
+ "in": "yin",
129
+ "u": "wu",
130
+ }
131
+ if pinyin in pinyin_rep_map.keys():
132
+ pinyin = pinyin_rep_map[pinyin]
133
+ else:
134
+ single_rep_map = {
135
+ "v": "yu",
136
+ "e": "e",
137
+ "i": "y",
138
+ "u": "w",
139
+ }
140
+ if pinyin[0] in single_rep_map.keys():
141
+ pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
142
+
143
+ assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
144
+ new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
145
+ new_v = new_v + tone
146
+ phone = [new_c, new_v]
147
+ word2ph.append(len(phone))
148
+
149
+ phones_list += phone
150
+ return phones_list, word2ph
151
+
152
+
153
+ def text_normalize(text):
154
+ # https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
155
+ tx = TextNormalizer()
156
+ sentences = tx.normalize(text)
157
+ dest_text = ""
158
+ for sentence in sentences:
159
+ dest_text += replace_punctuation(sentence)
160
+ return dest_text
161
+
162
+
163
+ if __name__ == "__main__":
164
+ text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏"
165
+ text = "呣呣呣~就是…大人的鼹鼠党吧?"
166
+ text = "你好"
167
+ text = text_normalize(text)
168
+ print(g2p(text))
169
+
170
+
171
+ # # 示例用法
172
+ # text = "这是一个示例文本:,你好!这是一个测试..."
173
+ # print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试
SongBloom/g2p/cn_zh_g2p/cmudict-fast.rep ADDED
The diff for this file is too large to render. See raw diff
 
SongBloom/g2p/cn_zh_g2p/cmudict.rep ADDED
The diff for this file is too large to render. See raw diff
 
SongBloom/g2p/cn_zh_g2p/engdict-hot.rep ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ CHATGPT CH AE1 T JH IY1 P IY1 T IY1
2
+ JSON JH EY1 S AH0 N
SongBloom/g2p/cn_zh_g2p/engdict_cache.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9bff9393f4b192d873a11335efc8f124771087b6dc847d34fd240c2846889d2b
3
+ size 5965909
SongBloom/g2p/cn_zh_g2p/english.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import os
3
+ import re
4
+ import wordsegment
5
+ from g2p_en import G2p
6
+
7
+ from string import punctuation
8
+
9
+ from .symbols import symbols
10
+
11
+ import unicodedata
12
+ from builtins import str as unicode
13
+ from g2p_en.expand import normalize_numbers
14
+
15
+ # Set NLTK data path programmatically to avoid needing set_env.sh
16
+ import nltk
17
+ current_file_path = os.path.dirname(__file__)
18
+ nltk_data_path = os.path.join(current_file_path, "nltk_data")
19
+ if os.path.exists(nltk_data_path):
20
+ nltk.data.path.insert(0, nltk_data_path)
21
+
22
+ from nltk.tokenize import TweetTokenizer
23
+ word_tokenize = TweetTokenizer().tokenize
24
+ from nltk import pos_tag
25
+
26
+ CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
27
+ CMU_DICT_FAST_PATH = os.path.join(current_file_path, "cmudict-fast.rep")
28
+ CMU_DICT_HOT_PATH = os.path.join(current_file_path, "engdict-hot.rep")
29
+ CACHE_PATH = os.path.join(current_file_path, "engdict_cache.pickle")
30
+ NAMECACHE_PATH = os.path.join(current_file_path, "namedict_cache.pickle")
31
+
32
+ arpa = {
33
+ "AH0",
34
+ "S",
35
+ "AH1",
36
+ "EY2",
37
+ "AE2",
38
+ "EH0",
39
+ "OW2",
40
+ "UH0",
41
+ "NG",
42
+ "B",
43
+ "G",
44
+ "AY0",
45
+ "M",
46
+ "AA0",
47
+ "F",
48
+ "AO0",
49
+ "ER2",
50
+ "UH1",
51
+ "IY1",
52
+ "AH2",
53
+ "DH",
54
+ "IY0",
55
+ "EY1",
56
+ "IH0",
57
+ "K",
58
+ "N",
59
+ "W",
60
+ "IY2",
61
+ "T",
62
+ "AA1",
63
+ "ER1",
64
+ "EH2",
65
+ "OY0",
66
+ "UH2",
67
+ "UW1",
68
+ "Z",
69
+ "AW2",
70
+ "AW1",
71
+ "V",
72
+ "UW2",
73
+ "AA2",
74
+ "ER",
75
+ "AW0",
76
+ "UW0",
77
+ "R",
78
+ "OW1",
79
+ "EH1",
80
+ "ZH",
81
+ "AE0",
82
+ "IH2",
83
+ "IH",
84
+ "Y",
85
+ "JH",
86
+ "P",
87
+ "AY1",
88
+ "EY0",
89
+ "OY2",
90
+ "TH",
91
+ "HH",
92
+ "D",
93
+ "ER0",
94
+ "CH",
95
+ "AO1",
96
+ "AE1",
97
+ "AO2",
98
+ "OY1",
99
+ "AY2",
100
+ "IH1",
101
+ "OW0",
102
+ "L",
103
+ "SH",
104
+ }
105
+
106
+
107
+ def replace_phs(phs):
108
+ rep_map = {"'": "-"}
109
+ phs_new = []
110
+ for ph in phs:
111
+ if ph in symbols:
112
+ phs_new.append(ph)
113
+ elif ph in rep_map.keys():
114
+ phs_new.append(rep_map[ph])
115
+ else:
116
+ print("ph not in symbols: ", ph)
117
+ return phs_new
118
+
119
+
120
+ def read_dict():
121
+ g2p_dict = {}
122
+ start_line = 49
123
+ with open(CMU_DICT_PATH) as f:
124
+ line = f.readline()
125
+ line_index = 1
126
+ while line:
127
+ if line_index >= start_line:
128
+ line = line.strip()
129
+ word_split = line.split(" ")
130
+ word = word_split[0].lower()
131
+
132
+ syllable_split = word_split[1].split(" - ")
133
+ g2p_dict[word] = []
134
+ for syllable in syllable_split:
135
+ phone_split = syllable.split(" ")
136
+ g2p_dict[word].append(phone_split)
137
+
138
+ line_index = line_index + 1
139
+ line = f.readline()
140
+
141
+ return g2p_dict
142
+
143
+
144
+ def read_dict_new():
145
+ g2p_dict = {}
146
+ with open(CMU_DICT_PATH) as f:
147
+ line = f.readline()
148
+ line_index = 1
149
+ while line:
150
+ if line_index >= 57:
151
+ line = line.strip()
152
+ word_split = line.split(" ")
153
+ word = word_split[0].lower()
154
+ g2p_dict[word] = [word_split[1].split(" ")]
155
+
156
+ line_index = line_index + 1
157
+ line = f.readline()
158
+
159
+ with open(CMU_DICT_FAST_PATH) as f:
160
+ line = f.readline()
161
+ line_index = 1
162
+ while line:
163
+ if line_index >= 0:
164
+ line = line.strip()
165
+ word_split = line.split(" ")
166
+ word = word_split[0].lower()
167
+ if word not in g2p_dict:
168
+ g2p_dict[word] = [word_split[1:]]
169
+
170
+ line_index = line_index + 1
171
+ line = f.readline()
172
+
173
+ return g2p_dict
174
+
175
+ def hot_reload_hot(g2p_dict):
176
+ with open(CMU_DICT_HOT_PATH) as f:
177
+ line = f.readline()
178
+ line_index = 1
179
+ while line:
180
+ if line_index >= 0:
181
+ line = line.strip()
182
+ word_split = line.split(" ")
183
+ word = word_split[0].lower()
184
+ # 自定义发音词直接覆盖字典
185
+ g2p_dict[word] = [word_split[1:]]
186
+
187
+ line_index = line_index + 1
188
+ line = f.readline()
189
+
190
+ return g2p_dict
191
+
192
+
193
+ def cache_dict(g2p_dict, file_path):
194
+ with open(file_path, "wb") as pickle_file:
195
+ pickle.dump(g2p_dict, pickle_file)
196
+
197
+
198
+ def get_dict():
199
+ if os.path.exists(CACHE_PATH):
200
+ with open(CACHE_PATH, "rb") as pickle_file:
201
+ g2p_dict = pickle.load(pickle_file)
202
+ else:
203
+ g2p_dict = read_dict_new()
204
+ cache_dict(g2p_dict, CACHE_PATH)
205
+
206
+ g2p_dict = hot_reload_hot(g2p_dict)
207
+
208
+ return g2p_dict
209
+
210
+
211
+ def get_namedict():
212
+ if os.path.exists(NAMECACHE_PATH):
213
+ with open(NAMECACHE_PATH, "rb") as pickle_file:
214
+ name_dict = pickle.load(pickle_file)
215
+ else:
216
+ name_dict = {}
217
+
218
+ return name_dict
219
+
220
+
221
+ def text_normalize(text):
222
+ # todo: eng text normalize
223
+ # 适配中文及 g2p_en 标点
224
+ rep_map = {
225
+ "[;::,;]": ",",
226
+ '["’]': "'",
227
+ "。": ".",
228
+ "!": "!",
229
+ "?": "?",
230
+ }
231
+ for p, r in rep_map.items():
232
+ text = re.sub(p, r, text)
233
+
234
+ # 来自 g2p_en 文本格式化处理
235
+ # 增加大写兼容
236
+ text = unicode(text)
237
+ text = normalize_numbers(text)
238
+ text = ''.join(char for char in unicodedata.normalize('NFD', text)
239
+ if unicodedata.category(char) != 'Mn') # Strip accents
240
+ text = re.sub("[^ A-Za-z'.,?!\-]", "", text)
241
+ text = re.sub(r"(?i)i\.e\.", "that is", text)
242
+ text = re.sub(r"(?i)e\.g\.", "for example", text)
243
+
244
+ return text
245
+
246
+
247
+ class en_G2p(G2p):
248
+ def __init__(self):
249
+ super().__init__()
250
+ # 分词初始化
251
+ wordsegment.load()
252
+
253
+ # 扩展过时字典, 添加姓名字典
254
+ self.cmu = get_dict()
255
+ self.namedict = get_namedict()
256
+
257
+ # 剔除读音错误的几个缩写
258
+ for word in ["AE", "AI", "AR", "IOS", "HUD", "OS"]:
259
+ del self.cmu[word.lower()]
260
+
261
+ # 修正多音字
262
+ self.homograph2features["read"] = (['R', 'IY1', 'D'], ['R', 'EH1', 'D'], 'VBP')
263
+ self.homograph2features["complex"] = (['K', 'AH0', 'M', 'P', 'L', 'EH1', 'K', 'S'], ['K', 'AA1', 'M', 'P', 'L', 'EH0', 'K', 'S'], 'JJ')
264
+
265
+
266
+ def __call__(self, text):
267
+ # tokenization
268
+ words = word_tokenize(text)
269
+ tokens = pos_tag(words) # tuples of (word, tag)
270
+
271
+ # steps
272
+ prons = []
273
+ for o_word, pos in tokens:
274
+ # 还原 g2p_en 小写操作逻辑
275
+ word = o_word.lower()
276
+
277
+ if re.search("[a-z]", word) is None:
278
+ pron = [word]
279
+ # 先把单字母推出去
280
+ elif len(word) == 1:
281
+ # 单读 A 发音修正, 这里需要原格式 o_word 判断大写
282
+ if o_word == "A":
283
+ pron = ['EY1']
284
+ else:
285
+ pron = self.cmu[word][0]
286
+ # g2p_en 原版多音字处理
287
+ elif word in self.homograph2features: # Check homograph
288
+ pron1, pron2, pos1 = self.homograph2features[word]
289
+ if pos.startswith(pos1):
290
+ pron = pron1
291
+ # pos1比pos长仅出现在read
292
+ elif len(pos) < len(pos1) and pos == pos1[:len(pos)]:
293
+ pron = pron1
294
+ else:
295
+ pron = pron2
296
+ else:
297
+ # 递归查找预测
298
+ pron = self.qryword(o_word)
299
+
300
+ prons.extend(pron)
301
+ prons.extend([" "])
302
+
303
+ return prons[:-1]
304
+
305
+
306
+ def qryword(self, o_word):
307
+ word = o_word.lower()
308
+
309
+ # 查字典, 单字母除外
310
+ if len(word) > 1 and word in self.cmu: # lookup CMU dict
311
+ return self.cmu[word][0]
312
+
313
+ # 单词仅首字母大写时查找姓名字典
314
+ if o_word.istitle() and word in self.namedict:
315
+ return self.namedict[word][0]
316
+
317
+ # oov 长度小于等于 3 直接读字母
318
+ if len(word) <= 3:
319
+ phones = []
320
+ for w in word:
321
+ # 单读 A 发音修正, 此处不存在大写的情况
322
+ if w == "a":
323
+ phones.extend(['EY1'])
324
+ else:
325
+ phones.extend(self.cmu[w][0])
326
+ return phones
327
+
328
+ # 尝试分离所有格
329
+ if re.match(r"^([a-z]+)('s)$", word):
330
+ phones = self.qryword(word[:-2])[:]
331
+ # P T K F TH HH 无声辅音结尾 's 发 ['S']
332
+ if phones[-1] in ['P', 'T', 'K', 'F', 'TH', 'HH']:
333
+ phones.extend(['S'])
334
+ # S Z SH ZH CH JH 擦声结尾 's 发 ['IH1', 'Z'] 或 ['AH0', 'Z']
335
+ elif phones[-1] in ['S', 'Z', 'SH', 'ZH', 'CH', 'JH']:
336
+ phones.extend(['AH0', 'Z'])
337
+ # B D G DH V M N NG L R W Y 有声辅音结尾 's 发 ['Z']
338
+ # AH0 AH1 AH2 EY0 EY1 EY2 AE0 AE1 AE2 EH0 EH1 EH2 OW0 OW1 OW2 UH0 UH1 UH2 IY0 IY1 IY2 AA0 AA1 AA2 AO0 AO1 AO2
339
+ # ER ER0 ER1 ER2 UW0 UW1 UW2 AY0 AY1 AY2 AW0 AW1 AW2 OY0 OY1 OY2 IH IH0 IH1 IH2 元音结尾 's 发 ['Z']
340
+ else:
341
+ phones.extend(['Z'])
342
+ return phones
343
+
344
+ # 尝试进行分词,应对复合词
345
+ comps = wordsegment.segment(word.lower())
346
+
347
+ # 无法分词的送回去预测
348
+ if len(comps)==1:
349
+ return self.predict(word)
350
+
351
+ # 可以分词的递归处理
352
+ return [phone for comp in comps for phone in self.qryword(comp)]
353
+
354
+
355
+ _g2p = en_G2p()
356
+
357
+
358
+ def g2p(text):
359
+ # g2p_en 整段推理,剔除不存在的arpa返回
360
+ phone_list = _g2p(text)
361
+ phones = [ph if ph != "<unk>" else "UNK" for ph in phone_list if ph not in [" ", "<pad>", "UW", "</s>", "<s>"]]
362
+
363
+ return replace_phs(phones)
364
+
365
+
366
+ if __name__ == "__main__":
367
+ print(g2p("hello"))
368
+ print(g2p(text_normalize("e.g. I used openai's AI tool to draw a picture.")))
369
+ print(g2p(text_normalize("In this; paper, we propose 1 DSPGAN, a GAN-based universal vocoder.")))
SongBloom/g2p/cn_zh_g2p/nltk_data/corpora/cmudict.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d07cca47fd72ad32ea9d8ad1219f85301eeaf4568f8b6b73747506a71fb5afd6
3
+ size 896069
SongBloom/g2p/cn_zh_g2p/nltk_data/corpora/cmudict/README ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The Carnegie Mellon Pronouncing Dictionary [cmudict.0.7a]
2
+
3
+ ftp://ftp.cs.cmu.edu/project/speech/dict/
4
+ https://cmusphinx.svn.sourceforge.net/svnroot/cmusphinx/trunk/cmudict/cmudict.0.7a
5
+
6
+ Copyright (C) 1993-2008 Carnegie Mellon University. All rights reserved.
7
+
8
+ File Format: Each line consists of an uppercased word,
9
+ a counter (for alternative pronunciations), and a transcription.
10
+ Vowels are marked for stress (1=primary, 2=secondary, 0=no stress).
11
+ E.g.: NATURAL 1 N AE1 CH ER0 AH0 L
12
+
13
+ The dictionary contains 127069 entries. Of these, 119400 words are assigned
14
+ a unique pronunciation, 6830 words have two pronunciations, and 839 words have
15
+ three or more pronunciations. Many of these are fast-speech variants.
16
+
17
+ Phonemes: There are 39 phonemes, as shown below:
18
+
19
+ Phoneme Example Translation Phoneme Example Translation
20
+ ------- ------- ----------- ------- ------- -----------
21
+ AA odd AA D AE at AE T
22
+ AH hut HH AH T AO ought AO T
23
+ AW cow K AW AY hide HH AY D
24
+ B be B IY CH cheese CH IY Z
25
+ D dee D IY DH thee DH IY
26
+ EH Ed EH D ER hurt HH ER T
27
+ EY ate EY T F fee F IY
28
+ G green G R IY N HH he HH IY
29
+ IH it IH T IY eat IY T
30
+ JH gee JH IY K key K IY
31
+ L lee L IY M me M IY
32
+ N knee N IY NG ping P IH NG
33
+ OW oat OW T OY toy T OY
34
+ P pee P IY R read R IY D
35
+ S sea S IY SH she SH IY
36
+ T tea T IY TH theta TH EY T AH
37
+ UH hood HH UH D UW two T UW
38
+ V vee V IY W we W IY
39
+ Y yield Y IY L D Z zee Z IY
40
+ ZH seizure S IY ZH ER
41
+
42
+ (For NLTK, entries have been sorted so that, e.g. FIRE 1 and FIRE 2
43
+ are contiguous, and not separated by FIRE'S 1.)
44
+
45
+ Redistribution and use in source and binary forms, with or without
46
+ modification, are permitted provided that the following conditions
47
+ are met:
48
+
49
+ 1. Redistributions of source code must retain the above copyright
50
+ notice, this list of conditions and the following disclaimer.
51
+ The contents of this file are deemed to be source code.
52
+
53
+ 2. Redistributions in binary form must reproduce the above copyright
54
+ notice, this list of conditions and the following disclaimer in
55
+ the documentation and/or other materials provided with the
56
+ distribution.
57
+
58
+ This work was supported in part by funding from the Defense Advanced
59
+ Research Projects Agency, the Office of Naval Research and the National
60
+ Science Foundation of the United States of America, and by member
61
+ companies of the Carnegie Mellon Sphinx Speech Consortium. We acknowledge
62
+ the contributions of many volunteers to the expansion and improvement of
63
+ this dictionary.
64
+
65
+ THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND
66
+ ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
67
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
68
+ PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
69
+ NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
70
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
71
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
72
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
73
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
74
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
75
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
76
+
SongBloom/g2p/cn_zh_g2p/nltk_data/corpora/cmudict/cmudict ADDED
The diff for this file is too large to render. See raw diff
 
SongBloom/g2p/cn_zh_g2p/nltk_data/taggers/averaged_perceptron_tagger.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1f13cf2532daadfd6f3bc481a49859f0b8ea6432ccdcd83e6a49a5f19008de9
3
+ size 2526731
SongBloom/g2p/cn_zh_g2p/nltk_data/taggers/averaged_perceptron_tagger/averaged_perceptron_tagger.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:25a5a19c7ced7b2bac3831da5bc0afcc2c34e5dd01cd4f361bb799949a696238
3
+ size 6138625
SongBloom/g2p/cn_zh_g2p/opencpop-strict.txt ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ a AA a
2
+ ai AA ai
3
+ an AA an
4
+ ang AA ang
5
+ ao AA ao
6
+ ba b a
7
+ bai b ai
8
+ ban b an
9
+ bang b ang
10
+ bao b ao
11
+ bei b ei
12
+ ben b en
13
+ beng b eng
14
+ bi b i
15
+ bian b ian
16
+ biao b iao
17
+ bie b ie
18
+ bin b in
19
+ bing b ing
20
+ bo b o
21
+ bu b u
22
+ ca c a
23
+ cai c ai
24
+ can c an
25
+ cang c ang
26
+ cao c ao
27
+ ce c e
28
+ cei c ei
29
+ cen c en
30
+ ceng c eng
31
+ cha ch a
32
+ chai ch ai
33
+ chan ch an
34
+ chang ch ang
35
+ chao ch ao
36
+ che ch e
37
+ chen ch en
38
+ cheng ch eng
39
+ chi ch ir
40
+ chong ch ong
41
+ chou ch ou
42
+ chu ch u
43
+ chua ch ua
44
+ chuai ch uai
45
+ chuan ch uan
46
+ chuang ch uang
47
+ chui ch ui
48
+ chun ch un
49
+ chuo ch uo
50
+ ci c i0
51
+ cong c ong
52
+ cou c ou
53
+ cu c u
54
+ cuan c uan
55
+ cui c ui
56
+ cun c un
57
+ cuo c uo
58
+ da d a
59
+ dai d ai
60
+ dan d an
61
+ dang d ang
62
+ dao d ao
63
+ de d e
64
+ dei d ei
65
+ den d en
66
+ deng d eng
67
+ di d i
68
+ dia d ia
69
+ dian d ian
70
+ diao d iao
71
+ die d ie
72
+ ding d ing
73
+ diu d iu
74
+ dong d ong
75
+ dou d ou
76
+ du d u
77
+ duan d uan
78
+ dui d ui
79
+ dun d un
80
+ duo d uo
81
+ e EE e
82
+ ei EE ei
83
+ en EE en
84
+ eng EE eng
85
+ er EE er
86
+ fa f a
87
+ fan f an
88
+ fang f ang
89
+ fei f ei
90
+ fen f en
91
+ feng f eng
92
+ fo f o
93
+ fou f ou
94
+ fu f u
95
+ ga g a
96
+ gai g ai
97
+ gan g an
98
+ gang g ang
99
+ gao g ao
100
+ ge g e
101
+ gei g ei
102
+ gen g en
103
+ geng g eng
104
+ gong g ong
105
+ gou g ou
106
+ gu g u
107
+ gua g ua
108
+ guai g uai
109
+ guan g uan
110
+ guang g uang
111
+ gui g ui
112
+ gun g un
113
+ guo g uo
114
+ ha h a
115
+ hai h ai
116
+ han h an
117
+ hang h ang
118
+ hao h ao
119
+ he h e
120
+ hei h ei
121
+ hen h en
122
+ heng h eng
123
+ hong h ong
124
+ hou h ou
125
+ hu h u
126
+ hua h ua
127
+ huai h uai
128
+ huan h uan
129
+ huang h uang
130
+ hui h ui
131
+ hun h un
132
+ huo h uo
133
+ ji j i
134
+ jia j ia
135
+ jian j ian
136
+ jiang j iang
137
+ jiao j iao
138
+ jie j ie
139
+ jin j in
140
+ jing j ing
141
+ jiong j iong
142
+ jiu j iu
143
+ ju j v
144
+ jv j v
145
+ juan j van
146
+ jvan j van
147
+ jue j ve
148
+ jve j ve
149
+ jun j vn
150
+ jvn j vn
151
+ ka k a
152
+ kai k ai
153
+ kan k an
154
+ kang k ang
155
+ kao k ao
156
+ ke k e
157
+ kei k ei
158
+ ken k en
159
+ keng k eng
160
+ kong k ong
161
+ kou k ou
162
+ ku k u
163
+ kua k ua
164
+ kuai k uai
165
+ kuan k uan
166
+ kuang k uang
167
+ kui k ui
168
+ kun k un
169
+ kuo k uo
170
+ la l a
171
+ lai l ai
172
+ lan l an
173
+ lang l ang
174
+ lao l ao
175
+ le l e
176
+ lei l ei
177
+ leng l eng
178
+ li l i
179
+ lia l ia
180
+ lian l ian
181
+ liang l iang
182
+ liao l iao
183
+ lie l ie
184
+ lin l in
185
+ ling l ing
186
+ liu l iu
187
+ lo l o
188
+ long l ong
189
+ lou l ou
190
+ lu l u
191
+ luan l uan
192
+ lun l un
193
+ luo l uo
194
+ lv l v
195
+ lve l ve
196
+ ma m a
197
+ mai m ai
198
+ man m an
199
+ mang m ang
200
+ mao m ao
201
+ me m e
202
+ mei m ei
203
+ men m en
204
+ meng m eng
205
+ mi m i
206
+ mian m ian
207
+ miao m iao
208
+ mie m ie
209
+ min m in
210
+ ming m ing
211
+ miu m iu
212
+ mo m o
213
+ mou m ou
214
+ mu m u
215
+ na n a
216
+ nai n ai
217
+ nan n an
218
+ nang n ang
219
+ nao n ao
220
+ ne n e
221
+ nei n ei
222
+ nen n en
223
+ neng n eng
224
+ ni n i
225
+ nian n ian
226
+ niang n iang
227
+ niao n iao
228
+ nie n ie
229
+ nin n in
230
+ ning n ing
231
+ niu n iu
232
+ nong n ong
233
+ nou n ou
234
+ nu n u
235
+ nuan n uan
236
+ nun n un
237
+ nuo n uo
238
+ nv n v
239
+ nve n ve
240
+ o OO o
241
+ ou OO ou
242
+ pa p a
243
+ pai p ai
244
+ pan p an
245
+ pang p ang
246
+ pao p ao
247
+ pei p ei
248
+ pen p en
249
+ peng p eng
250
+ pi p i
251
+ pian p ian
252
+ piao p iao
253
+ pie p ie
254
+ pin p in
255
+ ping p ing
256
+ po p o
257
+ pou p ou
258
+ pu p u
259
+ qi q i
260
+ qia q ia
261
+ qian q ian
262
+ qiang q iang
263
+ qiao q iao
264
+ qie q ie
265
+ qin q in
266
+ qing q ing
267
+ qiong q iong
268
+ qiu q iu
269
+ qu q v
270
+ qv q v
271
+ quan q van
272
+ qvan q van
273
+ que q ve
274
+ qve q ve
275
+ qun q vn
276
+ qvn q vn
277
+ ran r an
278
+ rang r ang
279
+ rao r ao
280
+ re r e
281
+ ren r en
282
+ reng r eng
283
+ ri r ir
284
+ rong r ong
285
+ rou r ou
286
+ ru r u
287
+ rua r ua
288
+ ruan r uan
289
+ rui r ui
290
+ run r un
291
+ ruo r uo
292
+ sa s a
293
+ sai s ai
294
+ san s an
295
+ sang s ang
296
+ sao s ao
297
+ se s e
298
+ sen s en
299
+ seng s eng
300
+ sha sh a
301
+ shai sh ai
302
+ shan sh an
303
+ shang sh ang
304
+ shao sh ao
305
+ she sh e
306
+ shei sh ei
307
+ shen sh en
308
+ sheng sh eng
309
+ shi sh ir
310
+ shou sh ou
311
+ shu sh u
312
+ shua sh ua
313
+ shuai sh uai
314
+ shuan sh uan
315
+ shuang sh uang
316
+ shui sh ui
317
+ shun sh un
318
+ shuo sh uo
319
+ si s i0
320
+ song s ong
321
+ sou s ou
322
+ su s u
323
+ suan s uan
324
+ sui s ui
325
+ sun s un
326
+ suo s uo
327
+ ta t a
328
+ tai t ai
329
+ tan t an
330
+ tang t ang
331
+ tao t ao
332
+ te t e
333
+ tei t ei
334
+ teng t eng
335
+ ti t i
336
+ tian t ian
337
+ tiao t iao
338
+ tie t ie
339
+ ting t ing
340
+ tong t ong
341
+ tou t ou
342
+ tu t u
343
+ tuan t uan
344
+ tui t ui
345
+ tun t un
346
+ tuo t uo
347
+ wa w a
348
+ wai w ai
349
+ wan w an
350
+ wang w ang
351
+ wei w ei
352
+ wen w en
353
+ weng w eng
354
+ wo w o
355
+ wu w u
356
+ xi x i
357
+ xia x ia
358
+ xian x ian
359
+ xiang x iang
360
+ xiao x iao
361
+ xie x ie
362
+ xin x in
363
+ xing x ing
364
+ xiong x iong
365
+ xiu x iu
366
+ xu x v
367
+ xv x v
368
+ xuan x van
369
+ xvan x van
370
+ xue x ve
371
+ xve x ve
372
+ xun x vn
373
+ xvn x vn
374
+ ya y a
375
+ yan y En
376
+ yang y ang
377
+ yao y ao
378
+ ye y E
379
+ yi y i
380
+ yin y in
381
+ ying y ing
382
+ yo y o
383
+ yong y ong
384
+ you y ou
385
+ yu y v
386
+ yv y v
387
+ yuan y van
388
+ yvan y van
389
+ yue y ve
390
+ yve y ve
391
+ yun y vn
392
+ yvn y vn
393
+ za z a
394
+ zai z ai
395
+ zan z an
396
+ zang z ang
397
+ zao z ao
398
+ ze z e
399
+ zei z ei
400
+ zen z en
401
+ zeng z eng
402
+ zha zh a
403
+ zhai zh ai
404
+ zhan zh an
405
+ zhang zh ang
406
+ zhao zh ao
407
+ zhe zh e
408
+ zhei zh ei
409
+ zhen zh en
410
+ zheng zh eng
411
+ zhi zh ir
412
+ zhong zh ong
413
+ zhou zh ou
414
+ zhu zh u
415
+ zhua zh ua
416
+ zhuai zh uai
417
+ zhuan zh uan
418
+ zhuang zh uang
419
+ zhui zh ui
420
+ zhun zh un
421
+ zhuo zh uo
422
+ zi z i0
423
+ zong z ong
424
+ zou z ou
425
+ zu z u
426
+ zuan z uan
427
+ zui z ui
428
+ zun z un
429
+ zuo z uo
SongBloom/g2p/cn_zh_g2p/symbols.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿
4
+ punctuation = ["!", "?", "…", ",", "."] # @是SP停顿
5
+ punctuation.append("-")
6
+ pu_symbols = punctuation + ["SP", "SP2", "SP3", "UNK"]
7
+ # pu_symbols = punctuation + ["SP", 'SP2', 'SP3','SP4', "UNK"]
8
+ pad = "_"
9
+
10
+ c = [
11
+ "AA",
12
+ "EE",
13
+ "OO",
14
+ "b",
15
+ "c",
16
+ "ch",
17
+ "d",
18
+ "f",
19
+ "g",
20
+ "h",
21
+ "j",
22
+ "k",
23
+ "l",
24
+ "m",
25
+ "n",
26
+ "p",
27
+ "q",
28
+ "r",
29
+ "s",
30
+ "sh",
31
+ "t",
32
+ "w",
33
+ "x",
34
+ "y",
35
+ "z",
36
+ "zh",
37
+ ]
38
+ v = [
39
+ "E1",
40
+ "En1",
41
+ "a1",
42
+ "ai1",
43
+ "an1",
44
+ "ang1",
45
+ "ao1",
46
+ "e1",
47
+ "ei1",
48
+ "en1",
49
+ "eng1",
50
+ "er1",
51
+ "i1",
52
+ "i01",
53
+ "ia1",
54
+ "ian1",
55
+ "iang1",
56
+ "iao1",
57
+ "ie1",
58
+ "in1",
59
+ "ing1",
60
+ "iong1",
61
+ "ir1",
62
+ "iu1",
63
+ "o1",
64
+ "ong1",
65
+ "ou1",
66
+ "u1",
67
+ "ua1",
68
+ "uai1",
69
+ "uan1",
70
+ "uang1",
71
+ "ui1",
72
+ "un1",
73
+ "uo1",
74
+ "v1",
75
+ "van1",
76
+ "ve1",
77
+ "vn1",
78
+ "E2",
79
+ "En2",
80
+ "a2",
81
+ "ai2",
82
+ "an2",
83
+ "ang2",
84
+ "ao2",
85
+ "e2",
86
+ "ei2",
87
+ "en2",
88
+ "eng2",
89
+ "er2",
90
+ "i2",
91
+ "i02",
92
+ "ia2",
93
+ "ian2",
94
+ "iang2",
95
+ "iao2",
96
+ "ie2",
97
+ "in2",
98
+ "ing2",
99
+ "iong2",
100
+ "ir2",
101
+ "iu2",
102
+ "o2",
103
+ "ong2",
104
+ "ou2",
105
+ "u2",
106
+ "ua2",
107
+ "uai2",
108
+ "uan2",
109
+ "uang2",
110
+ "ui2",
111
+ "un2",
112
+ "uo2",
113
+ "v2",
114
+ "van2",
115
+ "ve2",
116
+ "vn2",
117
+ "E3",
118
+ "En3",
119
+ "a3",
120
+ "ai3",
121
+ "an3",
122
+ "ang3",
123
+ "ao3",
124
+ "e3",
125
+ "ei3",
126
+ "en3",
127
+ "eng3",
128
+ "er3",
129
+ "i3",
130
+ "i03",
131
+ "ia3",
132
+ "ian3",
133
+ "iang3",
134
+ "iao3",
135
+ "ie3",
136
+ "in3",
137
+ "ing3",
138
+ "iong3",
139
+ "ir3",
140
+ "iu3",
141
+ "o3",
142
+ "ong3",
143
+ "ou3",
144
+ "u3",
145
+ "ua3",
146
+ "uai3",
147
+ "uan3",
148
+ "uang3",
149
+ "ui3",
150
+ "un3",
151
+ "uo3",
152
+ "v3",
153
+ "van3",
154
+ "ve3",
155
+ "vn3",
156
+ "E4",
157
+ "En4",
158
+ "a4",
159
+ "ai4",
160
+ "an4",
161
+ "ang4",
162
+ "ao4",
163
+ "e4",
164
+ "ei4",
165
+ "en4",
166
+ "eng4",
167
+ "er4",
168
+ "i4",
169
+ "i04",
170
+ "ia4",
171
+ "ian4",
172
+ "iang4",
173
+ "iao4",
174
+ "ie4",
175
+ "in4",
176
+ "ing4",
177
+ "iong4",
178
+ "ir4",
179
+ "iu4",
180
+ "o4",
181
+ "ong4",
182
+ "ou4",
183
+ "u4",
184
+ "ua4",
185
+ "uai4",
186
+ "uan4",
187
+ "uang4",
188
+ "ui4",
189
+ "un4",
190
+ "uo4",
191
+ "v4",
192
+ "van4",
193
+ "ve4",
194
+ "vn4",
195
+ "E5",
196
+ "En5",
197
+ "a5",
198
+ "ai5",
199
+ "an5",
200
+ "ang5",
201
+ "ao5",
202
+ "e5",
203
+ "ei5",
204
+ "en5",
205
+ "eng5",
206
+ "er5",
207
+ "i5",
208
+ "i05",
209
+ "ia5",
210
+ "ian5",
211
+ "iang5",
212
+ "iao5",
213
+ "ie5",
214
+ "in5",
215
+ "ing5",
216
+ "iong5",
217
+ "ir5",
218
+ "iu5",
219
+ "o5",
220
+ "ong5",
221
+ "ou5",
222
+ "u5",
223
+ "ua5",
224
+ "uai5",
225
+ "uan5",
226
+ "uang5",
227
+ "ui5",
228
+ "un5",
229
+ "uo5",
230
+ "v5",
231
+ "van5",
232
+ "ve5",
233
+ "vn5",
234
+ ]
235
+
236
+ v_without_tone = [
237
+ "E",
238
+ "En",
239
+ "a",
240
+ "ai",
241
+ "an",
242
+ "ang",
243
+ "ao",
244
+ "e",
245
+ "ei",
246
+ "en",
247
+ "eng",
248
+ "er",
249
+ "i",
250
+ "i0",
251
+ "ia",
252
+ "ian",
253
+ "iang",
254
+ "iao",
255
+ "ie",
256
+ "in",
257
+ "ing",
258
+ "iong",
259
+ "ir",
260
+ "iu",
261
+ "o",
262
+ "ong",
263
+ "ou",
264
+ "u",
265
+ "ua",
266
+ "uai",
267
+ "uan",
268
+ "uang",
269
+ "ui",
270
+ "un",
271
+ "uo",
272
+ "v",
273
+ "van",
274
+ "ve",
275
+ "vn",
276
+ ]
277
+
278
+ # japanese
279
+ ja_symbols = [
280
+ "I",
281
+ "N",
282
+ "U",
283
+ "a",
284
+ "b",
285
+ "by",
286
+ "ch",
287
+ "cl",
288
+ "d",
289
+ "dy",
290
+ "e",
291
+ "f",
292
+ "g",
293
+ "gy",
294
+ "h",
295
+ "hy",
296
+ "i",
297
+ "j",
298
+ "k",
299
+ "ky",
300
+ "m",
301
+ "my",
302
+ "n",
303
+ "ny",
304
+ "o",
305
+ "p",
306
+ "py",
307
+ "r",
308
+ "ry",
309
+ "s",
310
+ "sh",
311
+ "t",
312
+ "ts",
313
+ "u",
314
+ "v",
315
+ "w",
316
+ "y",
317
+ "z",
318
+ # "[", #上升调型
319
+ # "]", #下降调型
320
+ # "$", #结束符
321
+ # "^", #开始符
322
+ ]
323
+
324
+ arpa = {
325
+ "AH0",
326
+ "S",
327
+ "AH1",
328
+ "EY2",
329
+ "AE2",
330
+ "EH0",
331
+ "OW2",
332
+ "UH0",
333
+ "NG",
334
+ "B",
335
+ "G",
336
+ "AY0",
337
+ "M",
338
+ "AA0",
339
+ "F",
340
+ "AO0",
341
+ "ER2",
342
+ "UH1",
343
+ "IY1",
344
+ "AH2",
345
+ "DH",
346
+ "IY0",
347
+ "EY1",
348
+ "IH0",
349
+ "K",
350
+ "N",
351
+ "W",
352
+ "IY2",
353
+ "T",
354
+ "AA1",
355
+ "ER1",
356
+ "EH2",
357
+ "OY0",
358
+ "UH2",
359
+ "UW1",
360
+ "Z",
361
+ "AW2",
362
+ "AW1",
363
+ "V",
364
+ "UW2",
365
+ "AA2",
366
+ "ER",
367
+ "AW0",
368
+ "UW0",
369
+ "R",
370
+ "OW1",
371
+ "EH1",
372
+ "ZH",
373
+ "AE0",
374
+ "IH2",
375
+ "IH",
376
+ "Y",
377
+ "JH",
378
+ "P",
379
+ "AY1",
380
+ "EY0",
381
+ "OY2",
382
+ "TH",
383
+ "HH",
384
+ "D",
385
+ "ER0",
386
+ "CH",
387
+ "AO1",
388
+ "AE1",
389
+ "AO2",
390
+ "OY1",
391
+ "AY2",
392
+ "IH1",
393
+ "OW0",
394
+ "L",
395
+ "SH",
396
+ }
397
+
398
+ symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)
399
+ symbols = sorted(set(symbols))
400
+ if __name__ == "__main__":
401
+ print(len(symbols))
SongBloom/g2p/cn_zh_g2p/tone_sandhi.py ADDED
@@ -0,0 +1,806 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List
15
+ from typing import Tuple
16
+
17
+ import jieba_fast as jieba
18
+ from pypinyin import lazy_pinyin
19
+ from pypinyin import Style
20
+
21
+
22
+ class ToneSandhi:
23
+ def __init__(self):
24
+ self.must_neural_tone_words = {
25
+ "麻烦",
26
+ "麻利",
27
+ "鸳鸯",
28
+ "高粱",
29
+ "骨头",
30
+ "骆驼",
31
+ "马虎",
32
+ "首饰",
33
+ "馒头",
34
+ "馄饨",
35
+ "风筝",
36
+ "难为",
37
+ "队伍",
38
+ "阔气",
39
+ "闺女",
40
+ "门道",
41
+ "锄头",
42
+ "铺盖",
43
+ "铃铛",
44
+ "铁匠",
45
+ "钥匙",
46
+ "里脊",
47
+ "里头",
48
+ "部分",
49
+ "那么",
50
+ "道士",
51
+ "造化",
52
+ "迷糊",
53
+ "连累",
54
+ "这么",
55
+ "这个",
56
+ "运气",
57
+ "过去",
58
+ "软和",
59
+ "转悠",
60
+ "踏实",
61
+ "跳蚤",
62
+ "跟头",
63
+ "趔趄",
64
+ "财主",
65
+ "豆腐",
66
+ "讲究",
67
+ "记性",
68
+ "记号",
69
+ "认识",
70
+ "规矩",
71
+ "见识",
72
+ "裁缝",
73
+ "补丁",
74
+ "衣裳",
75
+ "衣服",
76
+ "衙门",
77
+ "街坊",
78
+ "行李",
79
+ "行当",
80
+ "蛤蟆",
81
+ "蘑菇",
82
+ "薄荷",
83
+ "葫芦",
84
+ "葡萄",
85
+ "萝卜",
86
+ "荸荠",
87
+ "苗条",
88
+ "苗头",
89
+ "苍蝇",
90
+ "芝麻",
91
+ "舒服",
92
+ "舒坦",
93
+ "舌头",
94
+ "自在",
95
+ "膏药",
96
+ "脾气",
97
+ "脑袋",
98
+ "脊梁",
99
+ "能耐",
100
+ "胳膊",
101
+ "胭脂",
102
+ "胡萝",
103
+ "胡琴",
104
+ "胡同",
105
+ "聪明",
106
+ "耽误",
107
+ "耽搁",
108
+ "耷拉",
109
+ "耳朵",
110
+ "老爷",
111
+ "老实",
112
+ "老婆",
113
+ "老头",
114
+ "老太",
115
+ "翻腾",
116
+ "罗嗦",
117
+ "罐头",
118
+ "编辑",
119
+ "结实",
120
+ "红火",
121
+ "累赘",
122
+ "糨糊",
123
+ "糊涂",
124
+ "精神",
125
+ "粮食",
126
+ "簸箕",
127
+ "篱笆",
128
+ "算计",
129
+ "算盘",
130
+ "答应",
131
+ "笤帚",
132
+ "笑语",
133
+ "笑话",
134
+ "窟窿",
135
+ "窝囊",
136
+ "窗户",
137
+ "稳当",
138
+ "稀罕",
139
+ "称呼",
140
+ "秧歌",
141
+ "秀气",
142
+ "秀才",
143
+ "福气",
144
+ "祖宗",
145
+ "砚台",
146
+ "码头",
147
+ "石榴",
148
+ "石头",
149
+ "石匠",
150
+ "知识",
151
+ "眼睛",
152
+ "眯缝",
153
+ "眨巴",
154
+ "眉毛",
155
+ "相声",
156
+ "盘算",
157
+ "白净",
158
+ "痢疾",
159
+ "痛快",
160
+ "疟疾",
161
+ "疙瘩",
162
+ "疏忽",
163
+ "畜生",
164
+ "生意",
165
+ "甘蔗",
166
+ "琵琶",
167
+ "琢磨",
168
+ "琉璃",
169
+ "玻璃",
170
+ "玫瑰",
171
+ "玄乎",
172
+ "狐狸",
173
+ "状元",
174
+ "特务",
175
+ "牲口",
176
+ "牙碜",
177
+ "牌楼",
178
+ "爽快",
179
+ "爱人",
180
+ "热闹",
181
+ "烧饼",
182
+ "烟筒",
183
+ "烂糊",
184
+ "点心",
185
+ "炊帚",
186
+ "灯笼",
187
+ "火候",
188
+ "漂亮",
189
+ "滑溜",
190
+ "溜达",
191
+ "温和",
192
+ "清楚",
193
+ "消息",
194
+ "浪头",
195
+ "活泼",
196
+ "比方",
197
+ "正经",
198
+ "欺负",
199
+ "模糊",
200
+ "槟榔",
201
+ "棺材",
202
+ "棒槌",
203
+ "棉花",
204
+ "核桃",
205
+ "栅栏",
206
+ "柴火",
207
+ "架势",
208
+ "枕头",
209
+ "���杷",
210
+ "机灵",
211
+ "本事",
212
+ "木头",
213
+ "木匠",
214
+ "朋友",
215
+ "月饼",
216
+ "月亮",
217
+ "暖和",
218
+ "明白",
219
+ "时候",
220
+ "新鲜",
221
+ "故事",
222
+ "收拾",
223
+ "收成",
224
+ "提防",
225
+ "挖苦",
226
+ "挑剔",
227
+ "指甲",
228
+ "指头",
229
+ "拾掇",
230
+ "拳头",
231
+ "拨弄",
232
+ "招牌",
233
+ "招呼",
234
+ "抬举",
235
+ "护士",
236
+ "折腾",
237
+ "扫帚",
238
+ "打量",
239
+ "打算",
240
+ "打点",
241
+ "打扮",
242
+ "打听",
243
+ "打发",
244
+ "扎实",
245
+ "扁担",
246
+ "戒指",
247
+ "懒得",
248
+ "意识",
249
+ "意思",
250
+ "情形",
251
+ "悟性",
252
+ "怪物",
253
+ "思量",
254
+ "怎么",
255
+ "念头",
256
+ "念叨",
257
+ "快活",
258
+ "忙活",
259
+ "志气",
260
+ "心思",
261
+ "得罪",
262
+ "张罗",
263
+ "弟兄",
264
+ "开通",
265
+ "应酬",
266
+ "庄稼",
267
+ "干事",
268
+ "帮手",
269
+ "帐篷",
270
+ "希罕",
271
+ "师父",
272
+ "师傅",
273
+ "巴结",
274
+ "巴掌",
275
+ "差事",
276
+ "工夫",
277
+ "岁数",
278
+ "屁股",
279
+ "尾巴",
280
+ "少爷",
281
+ "小气",
282
+ "小伙",
283
+ "将就",
284
+ "对头",
285
+ "对付",
286
+ "寡妇",
287
+ "家伙",
288
+ "客气",
289
+ "实在",
290
+ "官司",
291
+ "学问",
292
+ "学生",
293
+ "字号",
294
+ "嫁妆",
295
+ "媳妇",
296
+ "媒人",
297
+ "婆家",
298
+ "娘家",
299
+ "委屈",
300
+ "姑娘",
301
+ "姐夫",
302
+ "妯娌",
303
+ "妥当",
304
+ "妖精",
305
+ "奴才",
306
+ "女婿",
307
+ "头发",
308
+ "太阳",
309
+ "大爷",
310
+ "大方",
311
+ "大意",
312
+ "大夫",
313
+ "多少",
314
+ "多么",
315
+ "外甥",
316
+ "壮实",
317
+ "地道",
318
+ "地方",
319
+ "在乎",
320
+ "困难",
321
+ "嘴巴",
322
+ "嘱咐",
323
+ "嘟囔",
324
+ "嘀咕",
325
+ "喜欢",
326
+ "喇嘛",
327
+ "喇叭",
328
+ "商量",
329
+ "唾沫",
330
+ "哑巴",
331
+ "哈欠",
332
+ "哆嗦",
333
+ "咳嗽",
334
+ "和尚",
335
+ "告诉",
336
+ "告示",
337
+ "含糊",
338
+ "吓唬",
339
+ "后头",
340
+ "名字",
341
+ "名堂",
342
+ "合同",
343
+ "吆喝",
344
+ "叫唤",
345
+ "口袋",
346
+ "厚道",
347
+ "厉害",
348
+ "千斤",
349
+ "包袱",
350
+ "包涵",
351
+ "匀称",
352
+ "勤快",
353
+ "动静",
354
+ "动弹",
355
+ "功夫",
356
+ "力气",
357
+ "前头",
358
+ "刺猬",
359
+ "刺激",
360
+ "别扭",
361
+ "利落",
362
+ "利索",
363
+ "利害",
364
+ "分析",
365
+ "出息",
366
+ "凑合",
367
+ "凉快",
368
+ "冷战",
369
+ "冤枉",
370
+ "冒失",
371
+ "养活",
372
+ "关系",
373
+ "先生",
374
+ "兄弟",
375
+ "便宜",
376
+ "使唤",
377
+ "佩服",
378
+ "作坊",
379
+ "体面",
380
+ "位置",
381
+ "似的",
382
+ "伙计",
383
+ "休息",
384
+ "什么",
385
+ "人家",
386
+ "亲戚",
387
+ "亲家",
388
+ "交情",
389
+ "云彩",
390
+ "事情",
391
+ "买卖",
392
+ "主意",
393
+ "丫头",
394
+ "丧气",
395
+ "两口",
396
+ "东西",
397
+ "东家",
398
+ "世故",
399
+ "不由",
400
+ "不在",
401
+ "下水",
402
+ "下巴",
403
+ "上头",
404
+ "上司",
405
+ "丈夫",
406
+ "丈人",
407
+ "一辈",
408
+ "那个",
409
+ "菩萨",
410
+ "父亲",
411
+ "母亲",
412
+ "咕噜",
413
+ "邋遢",
414
+ "费用",
415
+ "冤家",
416
+ "甜头",
417
+ "介绍",
418
+ "荒唐",
419
+ "大人",
420
+ "泥鳅",
421
+ "幸福",
422
+ "熟悉",
423
+ "计划",
424
+ "扑腾",
425
+ "蜡烛",
426
+ "姥爷",
427
+ "照顾",
428
+ "喉咙",
429
+ "吉他",
430
+ "弄堂",
431
+ "蚂蚱",
432
+ "凤凰",
433
+ "拖沓",
434
+ "寒碜",
435
+ "糟蹋",
436
+ "倒腾",
437
+ "报复",
438
+ "逻辑",
439
+ "盘缠",
440
+ "喽啰",
441
+ "牢骚",
442
+ "咖喱",
443
+ "扫把",
444
+ "惦记",
445
+ }
446
+ self.must_not_neural_tone_words = {
447
+ "男子",
448
+ "女子",
449
+ "分子",
450
+ "原子",
451
+ "量子",
452
+ "莲子",
453
+ "石子",
454
+ "瓜子",
455
+ "电子",
456
+ "人人",
457
+ "虎虎",
458
+ "幺幺",
459
+ "干嘛",
460
+ "学子",
461
+ "哈哈",
462
+ "数数",
463
+ "袅袅",
464
+ "局地",
465
+ "以下",
466
+ "娃哈哈",
467
+ "花花草草",
468
+ "留得",
469
+ "耕地",
470
+ "想想",
471
+ "熙熙",
472
+ "攘攘",
473
+ "卵子",
474
+ "死死",
475
+ "冉冉",
476
+ "恳恳",
477
+ "佼佼",
478
+ "吵吵",
479
+ "打打",
480
+ "考考",
481
+ "整整",
482
+ "莘莘",
483
+ "落地",
484
+ "算子",
485
+ "家家户户",
486
+ "青青",
487
+ }
488
+ self.punc = ":,;。?!“”‘’':,;.?!"
489
+
490
+ # the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041
491
+ # e.g.
492
+ # word: "家里"
493
+ # pos: "s"
494
+ # finals: ['ia1', 'i3']
495
+ def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]:
496
+ # reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺
497
+ for j, item in enumerate(word):
498
+ if (
499
+ j - 1 >= 0
500
+ and item == word[j - 1]
501
+ and pos[0] in {"n", "v", "a"}
502
+ and word not in self.must_not_neural_tone_words
503
+ ):
504
+ finals[j] = finals[j][:-1] + "5"
505
+ ge_idx = word.find("个")
506
+ if len(word) >= 1 and word[-1] in "吧呢哈啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶":
507
+ finals[-1] = finals[-1][:-1] + "5"
508
+ elif len(word) >= 1 and word[-1] in "的地得":
509
+ finals[-1] = finals[-1][:-1] + "5"
510
+ # e.g. 走了, 看着, 去过
511
+ elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
512
+ finals[-1] = finals[-1][:-1] + "5"
513
+ elif (
514
+ len(word) > 1
515
+ and word[-1] in "们子"
516
+ and pos in {"r", "n"}
517
+ and word not in self.must_not_neural_tone_words
518
+ ):
519
+ finals[-1] = finals[-1][:-1] + "5"
520
+ # e.g. 桌上, 地下, 家里
521
+ elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
522
+ finals[-1] = finals[-1][:-1] + "5"
523
+ # e.g. 上来, 下去
524
+ elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开":
525
+ finals[-1] = finals[-1][:-1] + "5"
526
+ # 个做量词
527
+ elif (
528
+ ge_idx >= 1
529
+ and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
530
+ ) or word == "个":
531
+ finals[ge_idx] = finals[ge_idx][:-1] + "5"
532
+ else:
533
+ if (
534
+ word in self.must_neural_tone_words
535
+ or word[-2:] in self.must_neural_tone_words
536
+ ):
537
+ finals[-1] = finals[-1][:-1] + "5"
538
+
539
+ word_list = self._split_word(word)
540
+ finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
541
+ for i, word in enumerate(word_list):
542
+ # conventional neural in Chinese
543
+ if (
544
+ word in self.must_neural_tone_words
545
+ or word[-2:] in self.must_neural_tone_words
546
+ ):
547
+ finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
548
+ finals = sum(finals_list, [])
549
+ return finals
550
+
551
+ def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]:
552
+ # e.g. 看不懂
553
+ if len(word) == 3 and word[1] == "不":
554
+ finals[1] = finals[1][:-1] + "5"
555
+ else:
556
+ for i, char in enumerate(word):
557
+ # "不" before tone4 should be bu2, e.g. 不怕
558
+ if char == "不" and i + 1 < len(word) and finals[i + 1][-1] == "4":
559
+ finals[i] = finals[i][:-1] + "2"
560
+ return finals
561
+
562
+ def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
563
+ # "一" in number sequences, e.g. 一零零, 二一零
564
+ if word.find("一") != -1 and all(
565
+ [item.isnumeric() for item in word if item != "一"]
566
+ ):
567
+ return finals
568
+ # "一" between reduplication words shold be yi5, e.g. 看一看
569
+ elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
570
+ finals[1] = finals[1][:-1] + "5"
571
+ # when "一" is ordinal word, it should be yi1
572
+ elif word.startswith("第一"):
573
+ finals[1] = finals[1][:-1] + "1"
574
+ else:
575
+ for i, char in enumerate(word):
576
+ if char == "一" and i + 1 < len(word):
577
+ # "一" before tone4 should be yi2, e.g. 一段
578
+ if finals[i + 1][-1] == "4":
579
+ finals[i] = finals[i][:-1] + "2"
580
+ # "一" before non-tone4 should be yi4, e.g. 一天
581
+ else:
582
+ # "一" 后面如果是标点,还读一声
583
+ if word[i + 1] not in self.punc:
584
+ finals[i] = finals[i][:-1] + "4"
585
+ return finals
586
+
587
+ def _split_word(self, word: str) -> List[str]:
588
+ word_list = jieba.cut_for_search(word)
589
+ word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
590
+ first_subword = word_list[0]
591
+ first_begin_idx = word.find(first_subword)
592
+ if first_begin_idx == 0:
593
+ second_subword = word[len(first_subword) :]
594
+ new_word_list = [first_subword, second_subword]
595
+ else:
596
+ second_subword = word[: -len(first_subword)]
597
+ new_word_list = [second_subword, first_subword]
598
+ return new_word_list
599
+
600
+ def _three_sandhi(self, word: str, finals: List[str]) -> List[str]:
601
+ if len(word) == 2 and self._all_tone_three(finals):
602
+ finals[0] = finals[0][:-1] + "2"
603
+ elif len(word) == 3:
604
+ word_list = self._split_word(word)
605
+ if self._all_tone_three(finals):
606
+ # disyllabic + monosyllabic, e.g. 蒙古/包
607
+ if len(word_list[0]) == 2:
608
+ finals[0] = finals[0][:-1] + "2"
609
+ finals[1] = finals[1][:-1] + "2"
610
+ # monosyllabic + disyllabic, e.g. 纸/老虎
611
+ elif len(word_list[0]) == 1:
612
+ finals[1] = finals[1][:-1] + "2"
613
+ else:
614
+ finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
615
+ if len(finals_list) == 2:
616
+ for i, sub in enumerate(finals_list):
617
+ # e.g. 所有/人
618
+ if self._all_tone_three(sub) and len(sub) == 2:
619
+ finals_list[i][0] = finals_list[i][0][:-1] + "2"
620
+ # e.g. 好/喜欢
621
+ elif (
622
+ i == 1
623
+ and not self._all_tone_three(sub)
624
+ and finals_list[i][0][-1] == "3"
625
+ and finals_list[0][-1][-1] == "3"
626
+ ):
627
+ finals_list[0][-1] = finals_list[0][-1][:-1] + "2"
628
+ finals = sum(finals_list, [])
629
+ # split idiom into two words who's length is 2
630
+ elif len(word) == 4:
631
+ finals_list = [finals[:2], finals[2:]]
632
+ finals = []
633
+ for sub in finals_list:
634
+ if self._all_tone_three(sub):
635
+ sub[0] = sub[0][:-1] + "2"
636
+ finals += sub
637
+
638
+ return finals
639
+
640
+ def _all_tone_three(self, finals: List[str]) -> bool:
641
+ return all(x[-1] == "3" for x in finals)
642
+
643
+ # merge "不" and the word behind it
644
+ # if don't merge, "不" sometimes appears alone according to jieba, which may occur sandhi error
645
+ def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
646
+ new_seg = []
647
+ last_word = ""
648
+ for word, pos in seg:
649
+ if last_word == "不":
650
+ word = last_word + word
651
+ if word != "不":
652
+ new_seg.append((word, pos))
653
+ last_word = word[:]
654
+ if last_word == "不":
655
+ new_seg.append((last_word, "d"))
656
+ last_word = ""
657
+ return new_seg
658
+
659
+ # function 1: merge "一" and reduplication words in it's left and right, e.g. "听","一","听" ->"听一听"
660
+ # function 2: merge single "一" and the word behind it
661
+ # if don't merge, "一" sometimes appears alone according to jieba, which may occur sandhi error
662
+ # e.g.
663
+ # input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')]
664
+ # output seg: [['听一听', 'v']]
665
+ def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
666
+ new_seg = []
667
+ # function 1
668
+ for i, (word, pos) in enumerate(seg):
669
+ if (
670
+ i - 1 >= 0
671
+ and word == "一"
672
+ and i + 1 < len(seg)
673
+ and seg[i - 1][0] == seg[i + 1][0]
674
+ and seg[i - 1][1] == "v"
675
+ and seg[i + 1][1] == "v"
676
+ ):
677
+ new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0]
678
+ else:
679
+ if (
680
+ i - 2 >= 0
681
+ and seg[i - 1][0] == "一"
682
+ and seg[i - 2][0] == word
683
+ and pos == "v"
684
+ ):
685
+ continue
686
+ else:
687
+ new_seg.append([word, pos])
688
+ seg = new_seg
689
+ new_seg = []
690
+ # function 2
691
+ for i, (word, pos) in enumerate(seg):
692
+ if new_seg and new_seg[-1][0] == "一":
693
+ new_seg[-1][0] = new_seg[-1][0] + word
694
+ else:
695
+ new_seg.append([word, pos])
696
+ return new_seg
697
+
698
+ # the first and the second words are all_tone_three
699
+ def _merge_continuous_three_tones(
700
+ self, seg: List[Tuple[str, str]]
701
+ ) -> List[Tuple[str, str]]:
702
+ new_seg = []
703
+ sub_finals_list = [
704
+ lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
705
+ for (word, pos) in seg
706
+ ]
707
+ assert len(sub_finals_list) == len(seg)
708
+ merge_last = [False] * len(seg)
709
+ for i, (word, pos) in enumerate(seg):
710
+ if (
711
+ i - 1 >= 0
712
+ and self._all_tone_three(sub_finals_list[i - 1])
713
+ and self._all_tone_three(sub_finals_list[i])
714
+ and not merge_last[i - 1]
715
+ ):
716
+ # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
717
+ if (
718
+ not self._is_reduplication(seg[i - 1][0])
719
+ and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
720
+ ):
721
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
722
+ merge_last[i] = True
723
+ else:
724
+ new_seg.append([word, pos])
725
+ else:
726
+ new_seg.append([word, pos])
727
+
728
+ return new_seg
729
+
730
+ def _is_reduplication(self, word: str) -> bool:
731
+ return len(word) == 2 and word[0] == word[1]
732
+
733
+ # the last char of first word and the first char of second word is tone_three
734
+ def _merge_continuous_three_tones_2(
735
+ self, seg: List[Tuple[str, str]]
736
+ ) -> List[Tuple[str, str]]:
737
+ new_seg = []
738
+ sub_finals_list = [
739
+ lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
740
+ for (word, pos) in seg
741
+ ]
742
+ assert len(sub_finals_list) == len(seg)
743
+ merge_last = [False] * len(seg)
744
+ for i, (word, pos) in enumerate(seg):
745
+ if (
746
+ i - 1 >= 0
747
+ and sub_finals_list[i - 1][-1][-1] == "3"
748
+ and sub_finals_list[i][0][-1] == "3"
749
+ and not merge_last[i - 1]
750
+ ):
751
+ # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
752
+ if (
753
+ not self._is_reduplication(seg[i - 1][0])
754
+ and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
755
+ ):
756
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
757
+ merge_last[i] = True
758
+ else:
759
+ new_seg.append([word, pos])
760
+ else:
761
+ new_seg.append([word, pos])
762
+ return new_seg
763
+
764
+ def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
765
+ new_seg = []
766
+ for i, (word, pos) in enumerate(seg):
767
+ if i - 1 >= 0 and word == "儿" and seg[i - 1][0] != "#":
768
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
769
+ else:
770
+ new_seg.append([word, pos])
771
+ return new_seg
772
+
773
+ def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
774
+ new_seg = []
775
+ for i, (word, pos) in enumerate(seg):
776
+ if new_seg and word == new_seg[-1][0]:
777
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
778
+ else:
779
+ new_seg.append([word, pos])
780
+ return new_seg
781
+
782
+ def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
783
+ seg = self._merge_bu(seg)
784
+ try:
785
+ seg = self._merge_yi(seg)
786
+ except:
787
+ print("_merge_yi failed")
788
+ seg = self._merge_reduplication(seg)
789
+ try:
790
+ seg = self._merge_continuous_three_tones(seg)
791
+ except:
792
+ print("_merge_continuous_three_tones failed")
793
+ try:
794
+ seg = self._merge_continuous_three_tones_2(seg)
795
+ except:
796
+ print("_merge_continuous_three_tones_2 failed")
797
+
798
+ seg = self._merge_er(seg)
799
+ return seg
800
+
801
+ def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]:
802
+ finals = self._bu_sandhi(word, finals)
803
+ finals = self._yi_sandhi(word, finals)
804
+ finals = self._neural_sandhi(word, pos, finals)
805
+ finals = self._three_sandhi(word, finals)
806
+ return finals
SongBloom/g2p/cn_zh_g2p/zh_normalization/README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Supported NSW (Non-Standard-Word) Normalization
2
+
3
+ |NSW type|raw|normalized|
4
+ |:--|:-|:-|
5
+ |serial number|电影中梁朝伟扮演的陈永仁的编号27149|电影中梁朝伟扮演的陈永仁的编号二七一四九|
6
+ |cardinal|这块黄金重达324.75克<br>我们班的最高总分为583分|这块黄金重达三百二十四点七五克<br>我们班的最高总分为五百八十三分|
7
+ |numeric range |12\~23<br>-1.5\~2|十二到二十三<br>负一点五到二|
8
+ |date|她出生于86年8月18日,她弟弟出生于1995年3月1日|她出生于八六年八月十八日, 她弟弟出生于一九九五年三月一日|
9
+ |time|等会请在12:05请通知我|等会请在十二点零五分请通知我
10
+ |temperature|今天的最低气温达到-10°C|今天的最低气温达到零下十度
11
+ |fraction|现场有7/12的观众投出了赞成票|现场有十二分之七的观众投出了赞成票|
12
+ |percentage|明天有62%的概率降雨|明天有百分之六十二的概率降雨|
13
+ |money|随便来几个价格12块5,34.5元,20.1万|随便来几个价格十二块五,三十四点五元,二十点一万|
14
+ |telephone|这是固话0421-33441122<br>这是手机+86 18544139121|这是固话零四二一三三四四一一二二<br>这是手机八六一八五四四一三九一二一|
15
+ ## References
16
+ [Pull requests #658 of DeepSpeech](https://github.com/PaddlePaddle/DeepSpeech/pull/658/files)
SongBloom/g2p/cn_zh_g2p/zh_normalization/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from .text_normlization import *
SongBloom/g2p/cn_zh_g2p/zh_normalization/char_convert.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Traditional and simplified Chinese conversion, a simplified character may correspond to multiple traditional characters.
16
+ """
17
+ simplified_charcters = '制咖片型超声盘鉴定仔点他命书歌粉巾字帐恤手指记忆棒形转弯沟光○〇㐄㐅㐆㐌㐖毒㐜㐡㐤㐰㐺㑇㑳㒳㒸㔾㗂㗎㝵㞎㞙㞞以㢲㢴㤅㥁㥯㨗㫺㬎㮎㮚㮸㲋㲱㲾㳮涧㵪㶸㷖㷭㹢㹴犬㺢狓㺵碗㽮㿝䍃䔢䖟䖸䗈䗥䗪䝓射䥯䦉䯝鲃鱼䲔䳗鹅䵹鼄䶑一对应映射丁不识下儿子做二休世丘之貉并中台原则串为甚谓干净了百事无成八变五十些人得道鸡升天代如并来去个国政策劲幽灵在欧洲游荡接样萝卜坑侧化传价元论醇共再准刀两断切分耕耘收获钱货物向看旧就绪险刻千金动劳永逸匙零夜半卡通回复返影踪反常态口咬气句话同吐快吹周味呼诺呜品红锅哄而散起唱和问三知生熟团漆黑火糟堆场空块面塌糊涂尘染壁厢夔已足多情露水大早到晚夫妻当关万莫开失古恨套所料既往孔见提师要家主审寸阴难买斗牛小撮部阵局展身层巴掌帆风顺席地带过年计于春头载四季期被蛇怕井绳度愿式份弹顷深前律径心意念差愁孤行俱全房厅交遮打技长把抓死拿眼泪鼻涕钥锁折段抿拍即合扫排掬挥拨拥上入击洞掷揽改故辙败文值名斑方面旁族日秋餐隔雅里终父旦时晌会霎间晃暴寒曝更月望垠际朝夕本正经利杯羹东西板枝独秀根筋杆进条龙服务概模次函数又性程总付步脚印趋登毛拔呵氧氮碳决雌雄波未平派谎言流清楚白准溜烟潭有获闻是处降琴鹤甲病发可拾沙目然了直以相眨穿睹瞥瞬矢的解石鸟神教秉虔诚秘种窝蜂穷窍笑置笔苟勾销抹杀煞等奖箍节吃箭仇双雕诗筹箩筐系列纸级士官统丝毫挂维网尽线微吭响股脑胎脉承腔臂力致效资源址器举功投般说讲规贸易叶障着慎满皆输号木电池衣倾钟高低视仁觉醒览遗角银币触溃九鼎蔽抄出驷马追重语破贫洗贯走路安蹴至几蹶振跃役胆汗较辈轮辞赞退六连遍递边针血锤音错门思闪真倒项栽雾类保护川先惊乍体哄鳞爪鸣滴泡邻域党专鼓作齐炒丑烯亥克内酯冬加奴卯肝炎基尺梁街裤镐客宠庭巳汝昌烷玲磊糖肇酉醛啷青县韪良香骨鲷丂七集河市弦喜嘴张舌堵区工业姊妹星架构巧彩扭歪拼凑余热曜武州爷浮屠美乡老阶树荤素碎落能魄鳃鳗珠丄丅丆万俟丈尚摸母娘量管群亚虎必我堂令申件装伏位博侠义界表女墟台戏臭皮匠胜诸葛亮赛顶倍催请运算包立叉戟离疫苗土史志演围揭瓦晒夷姑婆帝村宝烂尖杉碱屉桌山岔岛由纪峡坝库镇废从德后拗汤治旬食明昧曹朋友框栏极权幂曲归依猫民氟硼氯磷铁江侗自旅法司洋浦梅园温暖湾焦班幸用田略番叠皇炮捶硝苯酸腺苷棱草镜穗跳远索锦纲聚氰胺联店胚膲爱色堇紫罗兰芝茶饭菱云虫藏藩乱叛苏亲债凳学座恐恋柱测肌腹衩锥系貂企乌跪叩军车农题迭都甘油屯奏键短阿姨陪姐只顾茅庐槽驾魂鲜鹿页其菜单乘任供势午齿汉组织吊调泻唇坡城报坟外夸将尉建筑岸岗公床扬新剑升杭林栗校楼标款汽社浣海商馆剧院钢华港机械广媒环球融第医科证券综财乐育游涨犹岭疏瘾睑确兵领导缴肢膛船艾瑟尔苍蔡虞效衫覆访诉课谕议轨述野钩限敌鞋颌颔颚饶首龈站例修凡划垂届属崽颏厨拜挫摆放旋削棋榻槛礼沉注滑营狱画确仪聘花葬诏员跌辖周达酒锚闸陷陆雨雪飞威丌于丹久乏予理评产亢卑亦乎舞己悲矩圆词害志但住佞佳便俗信票案幅翁倦伦假偏倚斜亏鬼敲停备伤脾胃仅此像俭匮免宜穴焉戴兼容许冻伯仲负彼昼皂轩轾实刊划颠卫战哥比省非好黄饰别拘束掩奶睬选择摇扰烦苦枚写协厌及格受欢迎约只估侵犯割状告或缺抗拒挽撤救药喻磨灭端倪少逆逾越避靠适吉誉吝玉含延咎歹听啻渊善谋均匀堪忍够太惹妙妥妨孕症孝术室完纳推冠积宣疑辩栗碴称屈挠屑干涉衡待很忙恶忿怎么怠急耻恭息悦惑惜惟想愉愧怍慌愤启懂懈怀材才紧招认扣抵拉舍也罢插揣冒搭撞南墙扩核支攻敢雷攀敬里吗需景智暇曾罪遇朽枉止况竞争辱求愈渝溶济左右袒困补爽特寂寞示弱找谢畏强疾徐痛痒冤符眠睦瞅董何厚云措活疲羞者轻玻璃祥兆禁���稂莠稳佛换答简结果盟绝缕途给谈否羁翼耐肖胫毋宁兴舒若菲莱痕迹窠臼虚衰脸兔撒鹰棺范该详讳抬泰让须眉象众赀账费灰赖奇虑训辍辨菽麦辛近送透逞徒速续逮捕遂遑违逊斧钺艰醉锈随观弃显饱脂肪使丏丐帮丒且慢末丕替桃宗王尊凉爵各图屋脊粮署录坛吾禄职胄袭君厦丗北壑桐疹损逢陵鹬丙寅戌氨腈唑纶辰酮脱氢酶醚丞丢现掉纱帽弄扯炮碗丠両丣坐存激肩臻蒂莲悖序驱丨丩丫挺杈髻鬟细介俄伊犁京尼布订普渡央委监察检查剂圈设警队斯督剩震境航舶革防托播促质版蝾螈锋研艺历残消频谱精密制造陲邮候埔坚压坜凹汇执府究邦俘摄寮彬狼岳肺肿庸英讯诊埋粒胞括控码韩暑枪枢砥澳哇牟寿甸钻探篇签缀缝继耳肯照妇埃悬璧轴柜台辣搁浅邪跑纤阮阳私囊魔丮丰姿采丱烧丳丵丶丷丸参寨朗桂瑞砂衷霞貌凤仆舰因嫌宰峰干络牌持旨祭祷簿编罚宾办丼丿乀乂乃乄仰慕盛旷留考验阔乆乇么丑麽乊湖燃乑乒乓乕乖僻忤戾离谬迕乗危肥劫除隙浪婿乙炔肠酰吡咯盐乚乛乜嘢卿玄宫尾狐龟塔嶷兄弟泉章霄钉耙乞扎哀怜恕讨乢乣乤乥乧乨乩童乪乫乭乳晕汁液瑶浆牙癌突窦罩腐胶猪酪蛋糕菌瘤乴乵乶乷乸乹乺乼乾俸冰嘉哕嚎坤妈尸垒旱枯涸俐渴潮涩煸豆燥爹瘦瘪癣瞪袋脆姜贝隆馏乿亀亁叫咕攘扔搞男砸窜蓬麻亃亄亅却亇迟典今临繁累卵奉婚聪躬巨与迁添裂副宿岁怪恶尕仑愣杆硅硫钛铀锰芑杂异钠砷胂磺琥珀舱棍簧胡茬盗浩盆贩郎腿亍洪亐互欠助勉惠操斥诿系户译亓墓碑刑铃卅渠缤纷斗米旗宪钒灯徽瘟祖拳福谷丰脏腑绑肉腌苓蕴桥铺霸颜闹判喷冈底蛙陉矿亖亘亜罕们娜桑那努哈喀弗烈曼松森杜氏杯奥琛敦戊穆圣裔汇薛孙亟亡佚虏羊牢奋释卷卸契媾感额睫缠谊趾塞挤纽阻还配驰庄亨洛祚亪享津沪畿郊慈菴枇杷膏亭阁锃丽亳亶亹诛初责翻疯偶杰丛稠妖拖寰居吸授慧蜗吞壮魅狗矛盾益渣患忧稀描猿梦暂涯畜祸缘沸搜引擎臣横纭谁混援蒸兽狮税剖亻亼亽亡什献刹邡么仂仃仄仆富怨仈仉毕昔晨壳绍仍仏仒仕宦仗欺恃腰叹叹炬梓讫施仙后琼逝仚仝仞仟悔仡佬偿填泊拓扑簇羔购顿钦佩发棻阃驭养亿儆尤借帧赈凌叙帖李柔刚沃眦睚戒讹取飨读仨仫仮著泳卧躺韶夏裁仳仵唯贤凭钓诞仿似宋佛讽伀硕盼鹅伄儅伈伉俪柯始娃迈戈坦堡帕茨萨庙玛莉莎藤霍姆伋伍奢胥廷芳豪伎俩侍汛勒希羲雏伐憩整谟闲闲伕伙伴颐伜伝伢叔恒兹恩翰伱伲侣伶俜悧鼬伸懒缩喇叭伹伺伻伽倻辐伾似佃伫布乔妮墨佉卢佌贷劣廉昂档浓矮伞洼缓耗胸谷迷挡率龋宅沫舍疗佐贰佑占优据铧尝呢须鲁晓佗佘余坪寺瓜铳僧蒙芒陀龛哼呕坊奸孽弊揖祟茧缚誓贼佝偻瞀佟你夺赶佡佢佣佤佧贾佪佫佯佰佱洁绩酿肴佴卷佶佷佸佹佺佻佼佽佾具唤窘坏娱怒慨硬习惯聋膨胀蔓骇贵痹侀侁侂侃侄侅鸿燕侇侈糜靡侉侌妾侏儒仓鼠侐侑侔仑侘侚链侜偎傍钴循柳葫芦附価侮骂蔑侯岩截蚀局贴壶嬛宴捷携桶笺酌俣狭膝狄俅俉俊俏俎俑俓俔谚俚俛黎健呈固墒增守康箱湿祐镖镳杠盒靖膜龄俞豹猎噪孚封札筒托衍鸽剪撰稿炼厂禊练缮葺俯瞰撑冲效俳俴俵俶俷俺备俾伥倂倅储卒惶敷猝逃颉蓄崇隐倌倏忽刺蜡烛噍嚼坍扁抽毙葱楣灌灶粪背薮卖赔闭霉腾倓倔幸倘倜傥倝借箸挹浇阅倡狂倢倣値倥偬倨傲倩匡嗣冲柝珍倬倭寇猩倮倶倷倹勤赞偁偃充伪吏嗓寐惺扮拱芫茜藉虢钞偈伟晶偌宕距析滤殿疼瘫注颇偓偕鸭歇滞偝偟偢忘怡旺偨偩逼偫偭偯偰偱偲侦缉蹄偷减惰漏窥窃偸偺迹傀儡傅傈僳骂篱傎奎琳迪叟芭傒傔傕伧悉荒傜傞傢傣芽逼佣婢傮睨寄檄诵谣颂伛担辜弓惨蒿悼疤傺傻屄臆巢泄箧羡盖轧颓傿㑩僄僇佥僊働僎侨僔僖僚僝伪僣僤侥僦猴偾僩僬僭僮僯僰雇僵殖签静僾僿征陇儁侬儃儇侩朴薄儊儋儌儍傧儓俦侪拟尽儜儞儤儦儩汰哉寡渥裕酷儭儱罐儳儵儹傩俨儽兀臬臲鹫允勋勋宙宵帅憝彝谐嫂阋畅沛溢盈饥赫凶悍狠猛顽愚妣斩秦遣鞭耀敏荣槃泽爆碟磁秃缆辉霁卤朵娄孜烽酱勃汀箕裘钳耶蒙蕾彻兑软遭黜兎児韵媳爸兕觥兖兙兛兜售鍪肚兝兞兟兡兢兣樽殓涅睡禀籍赘泌啡肽奸幕涵涝熵疚眷稃衬讧赴焕椒歼植跏没试误猜栖窗肋袖颊兪卦撇胡岐廓轿疸枫茴珑厕秩募勺吨寓斤历亩迫筷厘最淫螺韬兮宽匪筛襄赢轭复兲诈刃堰戎痞蚁饷它冀铸冂冃円冇冉册嫁厉砺竭醮冏牧冑冓冔冕冖冗冘冞冢窄抑诬冥冫烘菇蛰冷凝坨橇淇淋炭饼砖碛窖醋雕雹霜冱冶炉艳嘲峻滩淡漠煖飕饮冼冽凃凄怆梗凅凇净凊凋敝蒙凔凛遵汞脢凞几凢処凰凯凵凶焰凸折刷纹预丧喽奔巡榜殡芙蓉租笼辑鞘萃凼锯镬刁蛮刂娩崩批拆摊掰蘖骤歧颗秒袂赃勿嘱忌磋琢肤刈羽刎讼戮舂桨艇刓刖霹雳刜创犊刡恙墅帜筵致劫劫刨昏默攸尿欲熏润薰圭删刮痧铲刱刲刳刴刵踏磅戳柏槐绣芹苋猬舟铭鹄鹜劫剁剃辫刭锉履铅克剌姻咽哨廊掠桅沿召瞻翅赵卜渺茫郭剒剔剕沥剚愎毅讷才剜剥啄采剞剟剡剣剤䌽剐肾驶黏剰袍剀紊铲剸剺剽剿劁劂札劈啪柴扳啦刘奭姥夼昫涓熙禅禹锡翔雁鹗刽刿弩柄蜻蛉劒劓劖劘劙澜篑赏矶釜晋甜薪逐劦熔纣虐赤囚劬劭労劵效劻劼劾峭艮勅勇励勍勐腊脖庞漫饲荡粥辄勖勗勘骄馁碌泮雇捐竹骑殊阱绩朴恳谨剿勧勩勯勰劢勋勷劝惩慰诫谏勹芡践阑匁庇拯粟扎袱裹饺匆遽匈匉匊匋匍匐茎匏匕妆痰脓蛹斋苑烤蹈塘羌熊阀螳螂疆碚竿纬荷茵邙魏匚匜匝匟扶稷匣匦拢匸匹耦匽匾匿卂叮疮禧轸堤棚迢钧炼卄卆遐卉瓷盲瓶当胱腱裸卋卌卍卐怯污贱鄙龌龊陋卓溪唐梯渔陈枣泥漳浔涧梨芬谯赡辕迦郑単驴弈洽鳌卛占筮卝卞卟吩啉屎翠厄卣卨卪卬卮榫袄玺绶钮蚤惧殆笃耸卲帘帙绕恤卼卽厂厎厓厔厖厗奚厘厍厜厝谅厕厤厥厪腻孢厮厰厳厣厹厺粕垢芜菁厼厾叁悟茸薯叄吵笄悌哺讥坫垄弧芯杠潜婴刍袁诘贪谍煽馈驳収岳缔灾贿骗叚叡吻拦蘑蜜诀燧玩砚筝椎蔺铜逗骊另觅叨唠谒杵姓喊嚷嚣咚咛塑寻恼憎擦只泣渗蝠叱吒咄咤喝籀黛舵舷叵叶铎懿昭穰苴辽叻叼吁堑嫖赌瞧爬众抒吅吆夥卺橡涤抱纵摩郡唁坠扇篮膀袜颈吋忾谘酬哭妓媛暗表缰迩妃羿絮蕃浑拐葵暮隅吔吖啶嗪戚吜啬噬咽吟哦咏吠吧唧嗒咐吪隽咀征燐苞茹钙哧吮吰吱嘎吲哚吴栋娇窟孟箫忠晗淞阖闾趼宇呐睛嘘拂捧疵熄竽笛糠吼吽呀吕韦蒙呃呆笨呇贡呉罄呋喃呎呏呔呠呡痴呣呤呦呧瑛眩扒晬淑姬瑜璇鹃呪呫哔嚅嗫呬呯呰呱呲咧噌钝呴呶呷呸呺呻哱咻啸噜吁坎坷逻呿咁咂咆哮咇咈咋蟹煦珅蔼咍咑咒诅咔哒嚓咾哝哩喱咗咠咡咢咣咥咦咨嗟询咩咪咫啮啮咭咮咱咲咳呛嗽咴啕咸咹咺呙喉咿婉恸悯赋矜绿茗蓝哂抢瞒哆嗦啰噻啾滨彗哋哌哎唷哟哏哐哞哢哤哪里哫啼喘哰哲萎蚌哳咩哽哿呗唅唆唈唉唎唏哗尧棣殇璜睿肃唔睇唕吣唞唣喳唪唬唰喏唲唳唵嘛唶唸唹唻唼唾唿啁啃鹦鹉啅埠栈榷祺铺鞅飙啊啍啎啐啓啕啖啗啜哑祈啢衔啤啥啫啱啲啵啺饥啽噶昆沁喁喂喆裙喈咙喋喌喎喑喒喓喔粗喙幛庆滋鹊喟喣喤喥喦喧骚喨喩梆吃葡萄喭驼挑吓碰枞瓣纯疱藻趟铬喵営喹喺喼喿嗀嗃嗄嗅嗈嗉嗊嗍嗐嗑嗔诟嗕嗖嗙嗛嗜痂癖嗝嗡嗤嗥嗨唢嗬嗯嗰嗲嗵叽嗷嗹嗾嗿嘀嘁嘂嘅惋嘈峪禾荫啀嘌嘏嘐嘒啯啧嘚唛嘞嘟囔嘣嘥嘦嘧嘬嘭这谑严敞馋松哓嘶嗥呒虾嘹嘻啴嘿噀噂噅噇噉噎噏噔噗噘噙噚咝噞噢噤蝉皿噩噫噭嗳噱哙噳嚏涌洒欲巫霏噷噼嚃嚄嚆抖哜尝嚔苏嚚嚜嚞嚟呖嚬嚭嚮嚯亸喾饬按竣苛嚵嘤啭冁呓膪谦囍囒囓囗囘萧酚飘溅谛囝溯眸纥銮鹘囟殉囡団囤囥囧囨囱囫囵囬囮囯囲図囶囷囸囹圄圉拟囻囿圀圂圃圊粹蠹赦圌垦圏滚鲱凿枘圕圛圜圞坯埂壤骸炕祠窑豚绅魠鲮鳖圧握圩圪垯圬圮圯炸岬幔毯祇窨菩溉圳圴圻圾坂坆沾坋坌舛壈昆垫墩椅坒坓坩埚坭坰坱坳坴坵坻坼杨挣涎帘垃垈垌垍垓垔垕垗垚垛垝垣垞垟垤垧垮垵垺垾垿埀畔埄埆埇埈埌殃隍埏埒埕埗埜垭埤埦埧埭埯埰埲埳埴埵埶绋埸培怖桩础辅埼埽堀诃侄庑堃堄摧磐贞韧砌堈堉垩堋堌堍堎垴堙堞堠礁堧堨舆堭堮蜓摘堲堳堽堿塁塄塈煤茔棵塍垲埘塓绸塕鸦沽虱塙冢塝缪塡坞埙塥塩塬塱场螨塼塽塾塿墀墁墈墉墐夯増毁墝墠墦渍钵墫墬堕墰墺墙橱壅壆壊壌壎壒榨蒜壔壕壖圹垆壜壝垅壡壬壭壱売壴壹壻壸寝壿夂夅夆変夊夌漱邑夓腕泄甥御骼夗夘夙衮瑙妊娠醣枭珊莺鹭戗幻魇夤蹀秘擂鸫姚宛闺屿庾挞拇賛蛤裨菠氅漓捞湄蚊霆鲨箐篆篷荆肆舅荔鲆巷惭骰辟邱镕镰阪漂烩鲵鲽鳄鸨胪鹏妒峨谭枰晏玑癸祝秤竺牡籁恢罡蝼蝎赐绒御梭夬夭砣榆怙枕夶夹馅奄崛葩谲奈贺祀赠奌奂奓奕䜣詝奘奜奠奡奣陶奨奁魁奫奬奰娲孩贬隶酥宄狡猾她姹嫣妁毡荼皋膻蝇嫔妄妍嫉媚娆妗趣妚妞妤碍妬娅妯娌妲妳妵妺姁姅姉姗姒姘姙姜姝姞姣姤姧姫姮娥姱姸姺姽婀娀诱慑胁娉婷娑娓娟娣娭娯娵娶娸娼婊婐婕婞婤婥溪孺婧婪婬婹婺婼婽媁媄媊媕媞媟媠媢媬媮妫媲媵媸媺媻媪眯媿嫄嫈袅嫏嫕妪嫘嫚嫜嫠嫡嫦嫩嫪毐嫫嫬嫰妩嫺娴嫽嫿妫嬃嬅嬉耍婵痴艳嬔嬖嬗嫱袅嫒嬢嬷嬦嬬嬭幼嬲嬴婶嬹嬾嬿孀娘孅娈孏曰癫屏孑孓雀孖斟篓谜摺孛矻鸠崮轲祜鸾孥邈毓棠膑孬孭孰孱孳孵泛罔衔孻孪宀宁冗拙株薇掣抚琪瓿榴谧弥宊濂祁瑕宍宏碁宓邸谳実潢町宥宧宨宬徵崎骏掖阙臊煮禽蚕宸豫寀寁寥寃檐庶寎暄碜寔寖寘寙寛寠苫寤肘洱滥蒗陕核寪弘绰螽宝擅疙瘩晷対檐専尃尅赎绌缭畴衅尌峙醌襟痲碧屁昊槌淘恵瀑牝畑莓缸羚觑蔻脏躁尔尓锐尗尙尜尟尢��尨尪尬尭尰擒尲尶尴尸尹潽蠖蛾尻扣梢蚴鳍脬蹲屇屌蚵屐屃挪屖屘屙屛屝屡屣峦嶂岩舄屧屦屩屪屃屮戍驻钾崖嵛巅旮旯楂榄榉芋茱萸靛麓屴屹屺屼岀岊岌岍阜岑彭巩岒岝岢岚岣岧岨岫岱岵岷峁峇峋峒峓峞峠嵋峨峰峱岘峹峿崀崁崆祯崋崌崃岖昆崒崔嵬巍萤颢崚崞崟崠峥巆崤崦崧殂岽崱崳崴崶崿嵂嵇嵊泗嵌嵎嵒嵓岁嵙嵞嵡嵩嵫嵯嵴嵼嵾嵝崭崭晴嶋嶌嶒嶓嵚崂嶙嶝嶞峤嶡嶢峄嶨嶭嶮嶰嶲岙嵘巂巃巇巉岿巌巓巘巛滇芎巟巠弋回巣巤炊擘蜥蟒蛊觋巰蜀彦淖杏茂甫楞巻巽帼巿帛斐鲫蕊帑帔帗帚琉汶帟帡帣帨裙帯帰帷帹暆帏幄帮幋幌幏帻幙帮幞幠幡幢幦幨幩幪帱幭幯幰遥蹉跎馀庚鉴幵幷稚邃庀庁広庄庈庉笠庋跋庖牺庠庤庥鲸庬庱庳庴庵馨衢庹庿廃厩廆廋廌廎廏廐廑廒荫廖廛厮搏锣廞弛袤廥廧廨廪廱绵踵髓廸迫瓯邺廻廼廾廿躔弁皱弇弌弍弎弐弑吊诡憾荐弝弢弣弤弨弭弮弰弪霖繇焘斌旭溥骞弶弸弼弾彀彄别累纠强彔彖彘彟彟陌彤贻彧绘虹彪炳雕蔚鸥彰瘅彲彳彴仿彷徉徨彸彽踩敛旆徂徇徊渭畲铉裼従筌徘徙徜徕膳苏萌渐徬徭醺徯徳徴潘徻徼忀瘁胖燎怦悸颤扉犀澎湃砰恍惚绞隘忉惮挨饿忐忑忒忖応忝忞耿忡忪忭忮忱忸怩忻悠懑怏遏怔怗怚怛怞怼黍讶怫怭懦怱怲恍怵惕怸怹恁恂恇恉恌恏恒恓恔恘恚恛恝恞恟恠恣恧眄恪恫恬澹恰恿悀悁悃悄悆悊悐悒晦悚悛悜悝悤您悩悪悮悰悱凄恻德悴怅惘闷悻悾惄愫钟蒐惆惇惌惎惏惓惔惙惛耄惝疟浊恿惦德恽惴蠢惸拈愀愃愆愈愊愍愐愑愒愓愔愕恪氓蠢騃昵惬赧悫愬愮愯恺愼慁恿慅慆慇霭慉慊愠慝慥怄怂慬慱悭慴慵慷戚焚憀灼郁憃惫憋憍眺捏轼愦憔憖憙憧憬憨憪憭怃憯憷憸憹憺懃懅懆邀懊懋怿懔懐懞懠懤懥恹懫懮懰懱毖懵遁梁雍忏懽戁戄戆戉戋戕戛戝戛戠戡戢戣戤戥戦戬戭戯轰戱披菊牖戸戹戺戻卯戽锹扂楔扃扆扈扊杖牵绢铐镯赉扐搂搅烊盹瞌跟趸镲靶鼾払扗玫腮扛扞扠扡扢盔押扤扦扱罾揄绥鞍郤窾扻扼扽抃抆抈抉抌抏瞎抔缳缢擞抜拗択抨摔歉蹿牾抶抻搐泵菸拃拄拊髀抛拌脯拎拏拑擢秧沓曳挛迂拚拝拠拡拫拭拮踢拴拶拷攒拽掇芥橐簪摹疔挈瓢骥捺蹻挌挍挎挐拣挓挖掘浚挙揍聩挲挶挟挿捂捃捄捅捆捉捋胳膊揎捌捍捎躯蛛捗捘捙捜捥捩扪捭据捱捻捼捽掀掂抡臀膘掊掎掏掐笙掔掗掞棉芍掤搪阐掫掮掯揉掱掲掽掾揃揅揆搓揌诨揕揗揘揜揝揞揠揥揩揪揫橥遒麈揰揲揵揶揸背揺搆搉搊搋搌搎搔搕撼橹捣搘搠搡搢搣搤搥搦搧搨搬楦裢讪赸掏搰搲搳搴揾搷搽搾搿摀摁摂摃摎掴摒摓跤摙摛掼摞摠摦喉羯摭摮挚摰摲抠摴抟摷掺摽撂撃撅稻撊撋挦锏泼撕撙撚㧑挢撢掸撦撅撩撬撱朔揿蚍蜉挝捡擀掳闯擉缶觚擐擕擖擗擡擣擤澡腚擧擨擩擫擭摈拧撷擸撸擽擿攃摅撵攉攥攐攓撄搀撺每攩攫辔澄攮攰攲攴轶攷砭讦攽碘敁敃敇敉叙敎筏敔敕敖闰诲敜煌敧敪敳敹敺敻敿斁衽斄牒绉诌斉斎斓鹑谰驳鳢斒筲斛斝斞斠斡斢斨斫斮晾沂潟颖绛邵斲斸釳於琅斾斿旀旗旃旄涡旌旎旐旒旓旖旛旝旟旡旣浴旰獭魃旴时旻旼旽昀昃昄昇昉晰躲澈熹皎皓矾昑昕昜昝昞昡昤晖笋昦昨是昱昳昴昶昺昻晁蹇隧蔬髦晄晅晒晛晜晞晟晡晢晤晥曦晩萘莹顗晿暁暋暌暍暐暔暕煅旸暝暠暡曚暦暨暪朦胧昵暲殄冯暵暸暹暻暾曀晔昙曈曌曏曐暧曘曙曛叠昽曩骆曱甴肱曷牍禺锟曽沧耽朁朅朆杪栓夸竟粘绦朊膺朏朐朓朕朘朙瞄觐溘饔飧朠朢朣栅椆淀虱朩朮朰朱炆璋钰炽鹮朳槿朵朾朿杅杇杌陧欣钊湛漼楷瀍煜玟缨翱肇舜贽适逵杓杕杗杙荀蘅杝杞脩珓筊杰榔狍閦颦缅莞杲杳眇杴杶杸杻杼枋枌枒枓衾葄翘纾逋枙狸桠枟槁枲枳枴枵枷枸橼枹枻柁柂柃柅柈柊柎某柑橘柒柘柙柚柜柞栎柟柢柣柤柩柬柮柰柲橙柶柷柸柺査柿栃栄栒栔栘栝栟柏栩栫栭栱栲栳栴檀栵栻桀骜桁镁桄桉桋桎梏椹葚桓桔桕桜桟桫椤桭杯桯桲桴桷桹湘溟梃梊梍梐潼栀枧梜梠梡梣梧梩梱梲梳梴梵梹棁棃樱棐棑棕榈簑绷蓑枨棘棜棨棩棪棫棬棯棰棱棳棸棹椁棼碗椄苕椈椊椋椌椐椑椓椗検椤椪椰椳椴椵椷椸椽椿楀匾楅篪楋楍楎楗楘楙楛楝楟楠楢楥桢楩楪楫楬楮楯楰梅楸楹楻楽榀榃榊榎槺榕榖榘榛狉莽搒笞榠榡榤榥榦榧杩榭榰榱梿霰榼榾桤槊闩槎槑槔槖様槜槢槥椠槪槭椮槱槲槻槼槾樆樊樏樑樕樗樘樛樟樠樧樨権樲樴樵猢狲桦樻罍樾樿橁橄橆桡笥龠橕橚橛辆椭橤橧竖膈跨橾橿檩檃檇柽檍檎檑檖檗桧槚檠樯檨檫檬梼槟檴檵柠棹櫆櫌栉櫜椟櫡槠栌枥榇栊櫹棂茄櫽欀欂欃欐欑栾欙棂溴欨欬欱欵欶欷歔欸欹欻欼欿歁歃歆艎歈歊莳蝶歓歕歘歙歛歜欤歠蹦诠镶蹒跚升陟歩歮歯歰歳歴璞歺瞑歾殁夭殈殍殑殗殜殙殛殒殢殣殥殪殚僵殰殳荃殷殸殹蛟殻肴谤殴毈毉喂毎���蕈毗毘毚茛邓毧毬毳毷毹毽毾毵牦氄氆靴氉氊氇氍氐聊氕氖気氘氙氚氛氜氝氡汹焊痉氤氲氥氦铝锌氪烃氩铵痤汪浒漉痘盂碾菖蒲蕹蛭螅氵冰氹氺氽烫氾氿渚汆汊汋汍汎汏汐汔汕褟汙汚汜蓠沼秽蔑汧汨汩汭汲汳汴堤汾沄沅沆瀣沇沈葆浸沦湎溺痼疴沌沍沏沐沔沕沘浜畹砾沚沢沬沭沮沰沱灢沴沷籽沺烹濡洄泂肛泅泆涌肓泐泑泒泓泔泖泙泚泜泝泠漩馍涛粼泞藓鳅泩泫泭泯铢泱泲洇洊泾琵琶荽蓟箔洌洎洏洑潄濯洙洚洟洢洣洧洨洩痢滔洫洮洳洴洵洸洹洺洼洿淌蜚浄浉浙赣渫浠浡浤浥淼瀚浬浭翩萍浯浰蜃淀苔蛞蝓蜇螵蛸煲鲤浃浼浽溦涂涊涐涑涒涔滂莅涘涙涪涫涬涮涴涶涷涿淄淅淆淊凄黯淓淙涟淜淝淟淠淢淤渌淦淩猥藿亵淬淮淯淰淳诣涞纺淸淹炖癯绮渇済渉渋渓渕涣渟渢滓渤澥渧渨渮渰渲渶渼湅湉湋湍湑湓湔黔湜湝浈湟湢湣湩湫湮麟湱湲湴涅満沩溍溎溏溛舐漭溠溤溧驯溮溱溲溳溵溷溻溼溽溾滁滃滉滊荥滏稽滕滘汇滝滫滮羼耷卤滹浐煎漈漊漎绎漕漖漘漙沤漜漪漾漥漦漯漰溆漶漷濞潀颍潎潏潕潗潚潝潞潠潦祉疡潲潵滗潸潺潾涠澁澂澃澉澌澍澐澒澔澙渑澣澦澧澨澫澬浍澰澴澶澼熏郁濆濇濈濉濊貊濔疣濜濠濩觞浚濮盥潍濲泺瀁滢渎渖瀌浏瀒瀔濒泸瀛潇潆瀡潴泷濑瀬弥潋瀳瀵瀹瀺瀼沣滠灉灋灒漓灖灏灞灠滦灥灨滟灪蜴灮烬獴灴灸灺炁炅鱿炗炘炙炤炫疽烙钎炯炰炱炲炴炷毁炻烀烋瘴鲳烓烔焙烜烝烳饪烺焃焄耆焌焐焓焗焜焞焠焢焮焯焱焼煁煃煆煇煊熠煍熬煐炜煕暖熏硷霾煚煝煟煠茕矸煨琐炀萁煳煺煻熀熅熇熉罴荧穹炝熘熛熜稔谙烁熤熨熯熰眶蚂颎熳熸熿燀烨燂燄盏燊燋燏燔隼燖焖燠燡灿燨燮燹燻燽燿爇爊爓爚爝爟爨蟾爯爰为爻丬爿牀牁牂牄牋窗牏牓窗釉牚腩蒡虻牠虽蛎牣牤牮牯牲牳牴牷牸牼绊牿靬犂犄犆犇犉犍犎犒荦犗犛犟犠犨犩犪犮犰狳犴犵犺狁甩狃狆狎狒獾狘狙黠狨狩狫狴狷狺狻豕狈蜘猁猇猈猊猋猓猖獗猗猘狰狞犸猞猟獕猭猱猲猳猷猸猹猺玃獀獃獉獍獏獐獒毙獙獚獜獝獞獠獢獣獧鼇蹊狯猃獬豸狝獯鬻獳犷猕猡玁菟玅玆玈珉糁禛郅玍玎玓瓅玔玕玖玗玘玞玠玡玢玤玥玦珏瑰玭玳瑁玶玷玹玼珂珇珈瑚珌馐馔珔珖珙珛珞珡珣珥珧珩珪佩珶珷珺珽琀琁陨玡琇琖琚琠琤琦琨琫琬琭琮琯琰琱琲琅琴珐珲瑀瑂瑄瑉玮瑑瑔瑗瑢瑭瑱瑲瑳瑽瑾瑿璀璨璁璅璆璈琏璊璐璘璚璝璟璠璡璥瑷璩璪璫璯璲玙璸璺璿瓀璎瓖瓘瓒瓛脐瓞瓠瓤瓧瓩瓮瓰瓱瓴瓸瓻瓼甀甁甃甄甇甋甍甎甏甑甒甓甔瓮甖甗饴蔗甙诧钜粱盎锈团甡褥産甪甬甭甮宁铠甹甽甾甿畀畁畇畈畊畋畎畓畚畛畟鄂畤畦畧荻畯畳畵畷畸畽畾疃叠疋疍疎箪疐疒疕疘疝疢疥疧疳疶疿痁痄痊痌痍痏痐痒痔痗瘢痚痠痡痣痦痩痭痯痱痳痵痻痿瘀痖瘃瘈瘉瘊瘌瘏瘐痪瘕瘖瘙瘚瘛疭瘜瘝瘗瘠瘥瘨瘭瘆瘯瘰疬瘳疠瘵瘸瘺瘘瘼癃痨痫癈癎癐癔癙癜癠疖症癞蟆癪瘿痈発踔绀蔫酵皙砬砒翎翳蔹钨镴皑鹎驹暨粤褶皀皁荚皃镈皈皌皋皒朱皕皖皘皜皝皞皤皦皨皪皫皭糙绽皴皲皻皽盅盋碗盍盚盝踞盦盩秋千盬盭眦睁瞤盯盱眙裰盵盻睐眂眅眈眊県眑眕眚眛眞眢眣眭眳眴眵眹瞓眽郛睃睅睆睊睍睎困睒睖睙睟睠睢睥睪睾睯睽睾眯瞈瞋瞍逛瞏瞕瞖眍䁖瞟瞠瞢瞫瞭瞳瞵瞷瞹瞽阇瞿眬矉矍铄矔矗矙瞩矞矟矠矣矧矬矫矰矱硪碇磙罅舫阡、矼矽礓砃砅砆砉砍砑砕砝砟砠砢砦砧砩砫砮砳艏砵砹砼硇硌硍硎硏硐硒硜硖砗磲茚钡硭硻硾碃碉碏碣碓碔碞碡碪碫碬砀碯碲砜碻礴磈磉磎硙磔磕磖磛磟磠磡磤磥蹭磪磬磴磵磹磻硗礀硚礅礌礐礚礜礞礤礧礮砻礲礵礽礿祂祄祅祆禳祊祍祏祓祔祕祗祘祛祧祫祲祻祼饵脔锢禂禇禋祦禔祎隋禖禘禚禜禝禠祃禢禤禥禨禫祢禴禸秆秈秊闱飒秋秏秕笈蘵赁秠秣秪秫秬秭秷秸稊稌稍稑稗稙稛稞稬秸稲稹稼颡稿穂穄穇穈穉穋稣贮穏穜穟秾穑穣穤穧穨穭穮穵穸窿阒窀窂窅窆窈窕窊窋窌窒窗窔窞窣窬黩蹙窑窳窴窵窭窸窗竁竃竈竑竜并竦竖篦篾笆鲛竾笉笊笎笏笐靥笓笤箓笪笫笭笮笰笱笲笳笵笸笻筀筅筇筈筎筑筘筠筤筥筦笕筒筭箸筰筱筳筴宴筸箂个箊箎箑箒箘箙箛箜篌箝箠箬镞箯箴箾篁筼筜篘篙篚篛篜篝篟篠篡篢篥篧篨篭篰篲筚篴篶篹篼箦簁簃簆簉簋簌簏簜簟簠簥簦簨簬簰簸簻籊藤籒籓籔签籚篯箨籣籥籧笾簖籫籯芾麴籵籸籹籼粁秕粋粑粔粝粛粞粢粧粨粲粳稗粻粽辟粿糅糆糈糌糍糒糔萼糗蛆蹋糢糨糬粽糯糱籴粜糸糺紃蹼鲣霉纡纨绔纫闽襻紑纰纮锭鸢鹞纴紞紟扎紩紬绂绁纻紽紾绐絁絃絅経絍绗絏缡褵絓絖絘絜绚絣螯絪絫聒絰絵绝絺絻絿綀绡綅绠绨绣綌綍綎捆綖綘継続缎绻綦綪线綮綯绾罟蝽綷縩绺绫緁绲緅緆缁绯緌緎総緑绱緖缃缄缂绵缗緤褓缌纂緪緰缑缈缏缇縁縃縄萦缙缒縏缣縕缞縚缜缟缛縠縡縢縦绦縯縰骋缧縳纤缦絷缥縻衙縿繄缫繈繊繋繐缯繖繘繙繠缋繣繨缰缲繸繻缱纁纆纇缬缵纩纑纕缵纙纚纛缾罃罆坛罋罂罎罏罖罘罛罝罠罣罥罦罨罫罭锾罳罶罹罻罽罿羂羃羇芈蕉51鸵羑羖羌羜羝羢羣羟羧羭羮羰羱羵羶羸藜鲐翀翃翅翊翌翏翕翛翟翡翣翥翦跹翪翫翚翮翯翱翽翾翿板饕鸹锨耋耇耎耏专耒耜耔耞耡耤耨耩耪耧耰鬓耵聍聃聆聎聝聡聦聱聴聂聼阈聿肄肏肐肕腋肙肜肟肧胛肫肬肭肰肴肵肸肼胊胍胏胑胔胗胙胝胠铨胤胦胩胬胭胯胰胲胴胹胻胼胾脇脘脝脞脡脣脤脥脧脰脲脳腆腊腌臜腍腒腓胨腜腠脶腥腧腬腯踝蹬镣腴腶蠕诽膂腽嗉膇膋膔腘膗膙膟黐膣膦膫膰膴膵膷脍臃臄臇臈臌臐臑臓膘臖臙臛臝臞臧蓐诩臽臾臿舀舁鳑鲏舋舎舔舗馆舝舠舡舢舨舭舲舳舴舸舺艁艄艅艉艋艑艕艖艗艘艚艜艟艣舣艨艩舻艬艭荏艴艳艸艹艻艿芃芄芊萰陂藭芏芔芘芚蕙芟芣芤茉芧芨芩芪芮芰鲢芴芷芸荛豢芼芿苄苒苘苙苜蓿苠苡苣荬苤苎苪镑苶苹苺苻苾茀茁范蠡萣茆茇茈茌茍茖茞茠茢茥茦菰茭茯茳藨茷藘茼荁荄荅荇荈菅蜢鸮荍荑荘豆荵荸荠莆莒莔莕莘莙莚莛莜莝莦莨菪莩莪莭莰莿菀菆菉菎菏菐菑菓菔芲菘菝菡菢菣菥蓂菧菫毂蓥菶菷菹醢菺菻菼菾萅萆苌萋萏萐萑萜萩萱萴莴扁萻葇葍葎葑荭葖葙葠葥苇葧葭药葳葴葶葸葹葽蒄蒎莼茏薹莅蒟蒻蒢蒦蒨蒭藁蒯蒱鉾蒴蒹蒺蒽荪蓁蓆蓇蓊蓌蓍蓏蓓蓖蓧蓪蓫荜跣藕苁蓰蓱莼蓷蓺蓼蔀蔂蔃蔆蔇蔉蔊蔋蔌蔎蔕蔘蔙蒌蔟锷蒋雯茑蔯蔳麻蔵蔸蔾荨蒇蕋蕍荞蕐蕑芸莸蕖蕗蕝蕞蕠蕡蒉蕣蕤蕨蕳蓣蕸蕺蕻薀薁薃薅薆荟薉芗薏薐蔷薖薘剃谔钗薜薠薢薤薧薨薫薬薳薶薷薸薽薾薿藄藇藋荩藐藙藚藟藦藳藴苈藷藾蘀蘁蕲苹蘗蘘蘝蘤蘧蘩蘸蘼虀虆虍蟠虒虓虖虡虣虥虩虬虰蛵蛇虷鳟虺虼蚆蚈蚋蚓蚔蚖蚘蚜蚡蚣蚧蚨蚩蚪蚯蚰蜒蚱蚳蚶蚹蚺蚻蚿蛀蛁蛄蛅蝮蛌蛍蛐蟮蛑蛓蛔蛘蛚蛜蛡蛣蜊蛩蛱蜕螫蜅蚬蜈蝣蜋蜍蜎蜑蠊蜛饯蜞蜣蜨蜩蜮蜱蜷蜺蜾蜿蝀蝃蝋蝌蝍蝎蝏蝗蝘蝙蝝鲼蝡蝤蝥猿蝰虻蝲蝴蝻螃蠏蛳螉螋螒螓螗螘螙螚蟥螟螣螥螬螭䗖螾螀蟀蟅蝈蟊蟋蟑蟓蟛蟜蟟蟢虮蟨蟪蟭蛲蟳蛏蟷蟺蟿蠁蠂蠃虿蠋蛴蠓蚝蠗蠙蠚蠛蠜蠧蟏蠩蜂蠮蠰蠲蠵蠸蠼蠽衁衄衄衇衈衉衋衎衒同衖胡衞裳钩衭衲衵衹衺衿袈裟袗袚袟袢袪袮袲袴袷袺袼褙袽裀裉袅裋夹裍裎裒裛裯裱裲裴裾褀褂褉褊裈褎褐褒褓褔褕袆褚褡褢褦褧褪褫袅褯褰褱裆褛褽褾襁褒襆裥襉襋襌襏襚襛襜裣襞襡襢褴襦襫襬襭襮襕襶襼襽襾覂覃覅霸覉覊覌覗觇覚覜觍觎覧覩觊觏覰観觌觔觕觖觜觽觝觡酲觩觫觭觱觳觯觷觼觾觿言赅讣訇訏訑訒诂讬訧訬訳訹证訾詀詅诋毁詈詊讵詑诒诐詗诎察詨诜詶詸詹詻诙诖誂誃诔锄诓誋诳诶悖誙诮诰誧説読誯谇訚谄谆諆諌诤诹诼諕谂谀諝谝諟喧谥諴諵谌谖誊謆謇歌謍謏謑谡谥謡謦謪谪讴謷謼谩哗譅譆譈譊讹譒撰谮鑫譞噪譩谵譬譱譲谴譸譹谫讅讆詟䜩雠讐谗谶讙谠讟谽豁豉豇岂豊豋豌豏豔豞豖豗豜豝豣豦豨豭豱豳豵豶豷豺豻貅貆狸猊貔貘䝙貜貤餍贳餸贶贲赂賏赊赇赒賝赓赕賨赍斗賮賵賸赚赙赜赟贉赆赑贕赝赬赭赱赳迄趁趂趄趐趑趒趔趡趦趫趮趯趱趴趵趷趹趺趿跁跂跅跆踬跄跐跕跖跗跙跛跦跧跩跫跬跮跱跲跴跺跼跽踅踆踈踉踊踒踖踘踜踟躇蹰踠踡踣踤踥踦踧跷踫踮逾踱踊踶踹踺踼踽躞蹁蹂躏蹎蹐蹓蹔跸蹚蹜蹝迹蹠蹡蹢跶蹧蹩蹪蹯鞠蹽躃躄躅踌跻躐踯跞躘躙躗躝躠蹑躜躧躩躭躰躬躶軃軆辊軏轫軘軜軝腭転軥軨軭軱轱辘軷轵轺軽軿輀輂辇辂辁輈挽輗辄辎辋輠輤輬輭輮辏輴輵輶輹輼辗辒轇轏轑轒辚轕轖轗轘轙轝轞轹轳罪辣辞辵辶辺込辿迅迋迍麿迓迣迤逦迥迨迮迸迺迻迿逄逅逌逍逑逓迳逖逡逭逯逴逶逹遄遅侦遘遛遝遢遨遫遯遰遴绕遹遻邂邅邉邋邎邕邗邘邛邠邢邧邨邯郸邰邲邳邴邶邷邽邾邿郃郄郇郈郔郕郗郙郚郜郝郞郏郠郢郪郫郯郰郲郳郴郷郹郾郿鄀鄄郓鄇鄈鄋鄍鄎鄏鄐鄑邹邬鄕郧鄗鄘鄚鄜鄞鄠鄢鄣鄤鄦鄩鄫鄬鄮鄯鄱郐鄷鄹邝鄻鄾鄿酃酅酆酇郦酊酋酎酏酐酣酔酕醄酖酗酞酡酢酤酩酴酹酺醁醅醆醊醍醐醑醓醖醝酝醡醤醨醪醭醯醰酦醲醴醵醸醹醼醽醾釂酾酽釆釈鲈镏阊钆钇钌钯钋鼢鼹钐钏釪釬釭釱钍釸钕钫鈃钭鈆鈇钚鈊鈌钤钣鈒鈤钬钪鈬铌铈钶铛钹铍钸钿鉄鉆铊铇鉌铋鉏铂钷铆钵鉥钲鉨钼钽鉱鉲鉶铰铒鉼铪銍銎铣銕镂铫铦铑铷銤铱铟銧铥铕铯銭銰焊銶锑锉汞鋂锒鋆鋈鋊铤鋍铗鋐鋑鋕鋘鋙锊锓锔锇铓鋭铖锆锂铽鋳鋹鋺鉴镚钎錀锞锖锫锩錍铔锕錔锱铮锛錞锬锜錤錩錬録铼錼锝钔锴鍉镀鍏鍐铡鍚锻锽锸锲锘鍫鍭鍱鍴锶鍹锗针锺锿镅鎉鎋鎌鎍鎏鎒鎓鎗镉鎚鎞镃鎤铩锼鎭鎯镒镍鎴镓��鎹镎镟鏊镆镠镝鏖铿锵鏚镗镘镛鏠鏦錾镤鏸镪鏻鏽鏾铙鐄鐇鐏铹镦镡鐗馗镫镢镨鐡锎镄鐩镌鐬鐱镭鐶鐻鐽镱鑀鑅镔鑐鑕鑚鑛鑢鑤镥鑪镧鑯鑱鑴鑵镊镢钃镻闫闬闶闳閒闵閗閟阂関合閤哄阆閲阉閺阎阏阍阌暗闉阕阗闑闒闿闘闚阚闟闠闤闼阞阢阤阨阬阯阹阼阽陁陑陔陛陜陡陥陬骘陴険陼陾阴隃隈隒隗隞隠隣隤隩隮隰颧隳隷隹雂雈雉雊雎雑雒雗雘雚雝雟雩雰雱驿霂霅霈霊沾霒霓霙霝霢霣霤霨霩霪霫霮靁叇叆靑靓靣腼靪靮靰靳靷靸靺靼靿鞀鞃鞄鞍鞗鞙鞚鞝鞞鞡鞣鞨鞫鞬鞮鞶鞹鞾鞑韅鞯驮韍韎韔韖韘韝韫韡韣韭韭韱韹韺頀刮頄顸顼頍颀颃颁頖頞頠頫頬颅頯頲颕頼悴顋顑颙颛颜顕顚顜颟顣颥颞飐飑台飓颸飏飖颽颾颿飀飂飚飌翻飡飣饲飥饨饫飮飧飶餀餂饸饹餇餈饽哺馂餖餗餚馄馃餟餠餤餧餩餪餫糊餮糇餲饧馎糕饩馈馊馌馒饇馑馓膳饎饐饘饟馕馘馥馝馡馣骝骡馵馹駃駄駅駆駉駋驽駓驵駗骀驸駜骂骈駪駬骃駴骎駹駽駾騂騄骓騆騉騋骒骐麟騑騒験騕骛騠騢騣騤騧骧騵驺骟騺蓦骖骠骢驆驈骅驌骁驎骣驒驔驖驙驦驩驫骺鲠骫骭肮骱骴骶骷髅骾髁髂髄髆膀髇髑髌髋髙髝髞髟髡髣髧髪髫髭髯髲髳髹髺髽髾鬁鬃鬅鬈鬋鬎鬏鬐鬑鬒鬖鬗鬘鬙鬠鬣斗鬫鬬阄鬯鬰鬲鬵鬷魆魈魊魋魍魉魑魖鳔魛魟魣魦魨魬鲂魵魸鮀鲅鮆鲧鲇鲍鲋鮓鲒鲕鮟鱇鮠鮦鮨鲔鲑鮶鮸鮿鲧鯄鯆鲩鯈鲻鯕鲭鲞鯙鯠鲲鯥鲰鲶鳀鯸鳊鲗䲠鹣鳇鰋鳄鳆鰕鰛鰜鲥鰤鳏鰦鳎鳐鳁鳓鰶鲦鲡鰼鰽鱀鱄鳙鱆鳕鱎鱐鳝鳝鳜鲟鲎鱠鳣鱨鲚鱮鱲鱵鱻鲅鳦凫鳯鳲鳷鳻鴂鴃鴄鸩鴈鴎鸰鴔鴗鸳鸯鸲鹆鸱鴠鴢鸪鴥鸸鹋鴳鸻鴷鴽鵀鵁鸺鹁鵖鵙鹈鹕鹅鵟鵩鹌鵫鵵鵷鵻鹍鶂鶊鶏鶒鹙鶗鶡鶤鶦鶬鶱鹟鶵鶸鶹鹡鶿鹚鷁鷃鷄鷇䴘䴘鷊鷏鹧鷕鹥鸷鷞鷟鸶鹪鹩鷩鷫鷭鹇鹇鸴鷾䴙鸂鸇䴙鸏鸑鸒鸓鸬鹳鸜鹂鹸咸鹾麀麂麃麄麇麋麌麐麑麒麚麛麝麤麸面麫麮麯麰麺麾黁黈黉黢黒黓黕黙黝黟黥黦黧黮黰黱黪黶黹黻黼黾鼋鼂鼃鼅鼈鼍鼏鼐鼒冬鼖鼙鼚鼛鼡鼩鼱鼪鼫鼯鼷鼽齁齆齇齈齉齌赍齑龀齕齗龅齚龇齞龃龉龆齢出齧齩齮齯齰齱齵齾厐龑龒龚龖龘龝龡龢龤'
18
+
19
+ traditional_characters = '制咖片型超聲盤鑒定仔點他命書歌粉巾字帳恤手指記憶棒形轉彎溝光○〇㐄㐅㐆㐌㐖毒㐜㐡㐤㐰㐺㑇㑳㒳㒸㔾㗂㗎㝵㞎㞙㞞㠯㢲㢴㤅㥁㥯㨗㫺㬎㮎㮚㮸㲋㲱㲾㳮㵎㵪㶸㷖㷭㹢㹴犬㺢狓㺵㼝㽮㿝䍃䔢䖟䖸䗈䗥䗪䝓䠶䥯䦉䯝䰾魚䲔䳗䳘䵹鼄䶑一對應映射丁不識下兒子做二休世丘之貉並中台原則串為甚謂乾淨了百事無成八變五十些人得道雞升天代如併來去個國政策勁幽靈在歐洲遊蕩接樣蘿蔔坑側化傳價元論醇共再准刀兩斷切分耕耘收穫錢貨物向看舊就緒險刻千金動勞永逸匙零夜半卡通回復返影蹤反常態口咬氣句話同吐快吹周味呼諾嗚品紅鍋哄而散起唱和問三知生熟團漆黑火糟堆場空塊麵塌糊塗塵染壁廂夔已足多情露水大早到晚夫妻當關萬莫開失古恨套所料既往孔見提師要家主審寸陰難買鬥牛小撮部陣局展身層巴掌帆風順席地帶過年計於春頭載四季期被蛇怕井繩度願式份彈頃深前律徑心意念差愁孤行俱全房廳交遮打技長把抓死拿眼淚鼻涕鑰鎖折段抿拍即合掃排掬揮撥擁上入擊洞擲攬改故轍敗文值名斑方面旁族日秋餐隔雅里終父旦時晌會霎間晃暴寒曝更月望垠際朝夕本正經利杯羹東西板枝獨秀根筋桿進條龍服務概模次函數又性程總付步腳印趨登毛拔呵氧氮碳決雌雄波未平派謊言流清楚白準溜煙潭有獲聞是處降琴鶴甲病發可拾沙目然瞭直以相眨穿睹瞥瞬矢的解石鳥神教秉虔誠秘種窩蜂窮竅笑置筆苟勾銷抹殺煞等獎箍節吃箭仇雙鵰詩籌籮筐系列紙級士官統絲毫掛維網盡線微吭響股腦胎脈承腔臂力致效資源址器舉功投般說講規貿易葉障著慎滿皆輸號木電池衣傾鐘高低視仁覺醒覽遺角銀幣觸潰九鼎蔽抄出駟馬追重語破貧洗貫走路安蹴至幾蹶振躍役膽汗較輩輪辭贊退六連遍遞邊針血錘音錯門思閃真倒項栽霧類保護川先驚乍體鬨鱗爪鳴滴泡鄰域黨專鼓作齊炒丑烯亥克內酯冬加奴卯肝炎基尺梁街褲鎬客寵庭巳汝昌烷玲磊糖肇酉醛啷青縣韙良香骨鯛丂七集河市弦喜嘴張舌堵區工業姊妹星架構巧彩扭歪拼湊餘熱曜武州爺浮屠美鄉老階樹葷素碎落能魄鰓鰻珠丄丅丆万俟丈尚摸母娘量管群亞虎必我堂令申件裝伏位博俠義界表女墟臺戲臭皮匠勝諸葛亮賽頂倍催請運算包立叉戟離疫苗土史志演圍揭瓦曬夷姑婆帝村寶爛尖杉鹼屜桌山岔島由紀峽壩庫鎮廢從德後拗湯治旬食明昧曹朋友框欄極權冪曲歸依貓民氟硼氯磷鐵江侗自旅法司洋浦梅園溫暖灣焦班幸用田略番疊皇炮捶硝苯酸腺苷稜草鏡穗跳遠索錦綱聚氰胺聯店胚膲愛色堇紫羅蘭芝茶飯菱雲蟲藏藩亂叛蘇親債凳學座恐戀柱測肌腹衩錐係貂企烏跪叩軍車農題迭都甘油屯奏鍵短阿姨陪姐隻顧茅廬槽駕魂鮮鹿頁其菜單乘任供勢午齒漢組織吊調瀉唇坡城報墳外夸將尉建築岸崗公床揚新劍昇杭林栗校樓標款汽社浣海商館劇院鋼華港機械廣媒環球融第醫科證券綜財樂育游漲猶嶺疏癮瞼確兵領導繳肢膛船艾瑟爾蒼蔡虞傚衫覆訪訴課諭議軌述野鉤限敵鞋頜頷顎饒首齦站例修凡劃垂屆屬崽頦廚拜挫擺放旋削棋榻檻禮沉注滑營獄畫确儀聘花葬詔員跌轄週達酒錨閘陷陸雨雪飛威丌于丹久乏予理評產亢卑亦乎舞己悲矩圓詞害誌但住佞佳便俗信票案幅翁倦倫假偏倚斜虧鬼敲停備傷脾胃僅此像儉匱免宜穴焉戴兼容許凍伯仲負彼晝皂軒輊實刊划顛衛戰哥比省非好黃飾別拘束掩奶睬選擇搖擾煩苦枚寫協厭及格受歡迎約只估侵犯割狀告或缺抗拒挽撤救藥喻磨滅端倪少逆逾越避靠適吉譽吝玉含延咎歹聽啻淵善謀均勻堪忍夠太惹妙妥妨孕症孝術室完納推冠積宣疑辯慄碴稱屈撓屑干涉衡待很忙惡忿怎麼怠急恥恭息悅惑惜惟想愉愧怍慌憤啟懂懈懷材才緊招認扣抵拉捨也罷插揣冒搭撞南牆擴核支攻敢雷攀敬裡嗎需景智暇曾罪遇朽枉止況競爭辱求癒渝溶濟左右袒困補爽特寂寞示弱找謝畏強疾徐痛癢冤符眠睦瞅董何厚云措活疲羞者輕玻璃祥兆禁移稂莠穩佛換答簡結果盟絕縷途給談否羈翼耐肖脛毋寧興舒若菲萊痕跡窠臼虛衰臉兔撒鷹棺範該詳諱抬泰讓鬚眉象眾貲賬費灰賴奇慮訓輟辨菽麥辛近送透逞徒速續逮捕遂遑違遜斧鉞艱醉鏽隨觀棄顯飽脂肪使丏丐幫丒且慢末丕替桃宗王尊涼爵各圖屋脊糧署錄壇吾祿職胄襲君廈丗北壑桐疹損逢陵鷸丙寅戌氨腈唑綸辰酮脫氫酶醚丞丟現掉紗帽弄扯砲碗丠両丣坐存激肩臻蒂蓮悖序驅丨丩丫挺杈髻鬟細介俄伊犁京尼布訂普渡央委監察檢查劑圈設警隊斯督剩震境航舶革防托播促質版蠑螈鋒研藝歷殘消頻譜精密製造陲郵候埔堅壓壢凹匯執府究邦俘攝寮彬狼嶽肺腫庸英訊診埋粒胞括控碼韓暑槍樞砥澳哇牟壽甸鑽探篇簽綴縫繼耳肯照婦埃懸璧軸櫃檯辣擱淺邪跑纖阮陽私囊魔丮丰姿采丱燒丳丵丶丷丸參寨朗桂瑞砂衷霞貌鳳僕艦因嫌宰峰幹絡牌持旨祭禱簿編罰賓辦丼丿乀乂乃乄仰慕盛曠留考驗闊乆乇么醜麼乊湖燃乑乒乓乕乖僻忤戾离謬迕乗危肥劫除隙浪婿乙炔腸酰吡咯鹽乚乛乜嘢卿玄宮尾狐龜塔嶷兄弟泉章霄釘耙乞扎哀憐恕討乢乣乤乥乧乨乩童乪乫乭乳暈汁液瑤漿牙癌突竇罩腐膠豬酪蛋糕菌瘤乴乵乶乷乸乹乺乼乾俸冰嘉噦嚎坤媽屍壘旱枯涸俐渴潮澀煸豆燥爹瘦癟癬瞪袋脆薑貝隆餾乿亀亁叫咕攘扔搞男砸竄蓬麻亃亄亅卻亇遲典今臨繁累卵奉婚聰躬巨與遷添裂副宿歲怪噁尕崙愣杆硅硫鈦鈾錳芑雜異鈉砷胂磺琥珀艙棍簧胡茬盜浩盆販郎腿亍洪亐互欠助勉惠操斥諉繫戶譯亓墓碑刑鈴卅渠繽紛斗米旗憲釩燈徽瘟祖拳福穀豐臟腑綁肉醃苓蘊橋鋪霸顏鬧判噴岡底蛙陘礦亖亙亜罕們娜桑那努哈喀弗烈曼松森杜氏盃奧琛敦戊穆聖裔彙薛孫亟亡佚虜羊牢奮釋卷卸契媾感額睫纏誼趾塞擠紐阻還配馳莊亨洛祚亪享津滬畿郊慈菴枇杷膏亭閣鋥麗亳亶亹誅初責翻瘋偶傑叢稠妖拖寰居吸授慧蝸吞壯魅狗矛盾益渣患憂稀描猿夢暫涯畜禍緣沸搜引擎臣橫紜誰混援蒸獸獅稅剖亻亼亽亾什獻剎邡麽仂仃仄仆富怨仈仉畢昔晨殼紹仍仏仒仕宦仗欺恃腰嘆歎炬梓訖施仙后瓊逝仚仝仞仟悔仡佬償填泊拓撲簇羔購頓欽佩髮棻閫馭養億儆尤藉幀賑凌敘帖李柔剛沃眥睚戒訛取饗讀仨仫仮著泳臥躺韶夏裁仳仵唯賢憑釣誕仿似宋彿諷伀碩盼鵝伄儅伈伉儷柯始娃邁戈坦堡帕茨薩廟瑪莉莎藤霍姆伋伍奢胥廷芳豪伎倆侍汛勒希羲雛伐憩整謨閑閒伕伙伴頤伜伝伢叔恆茲恩翰伱伲侶伶俜悧鼬伸懶縮喇叭伹伺伻伽倻輻伾佀佃佇佈喬妮墨佉盧佌貸劣廉昂檔濃矮傘窪緩耗胸谷迷擋率齲宅沫舍療佐貳佑佔優據鏵嘗呢須魯曉佗佘余坪寺瓜銃僧蒙芒陀龕哼嘔坊姦孽弊揖祟繭縛誓賊佝僂瞀佟你奪趕佡佢佣佤佧賈佪佫佯佰佱潔績釀餚佴捲佶佷佸佹佺佻佼佽佾具喚窘壞娛怒慨硬習慣聾膨脹蔓駭貴痺侀侁侂侃侄侅鴻燕侇侈糜靡侉侌妾侏儒倉鼠侐侑侔侖侘侚鏈侜偎傍鈷循柳葫蘆附価侮罵蔑侯岩截蝕侷貼壺嬛宴捷攜桶箋酌俁狹膝狄俅俉俊俏俎俑俓俔諺俚俛黎健呈固墒增守康箱濕祐鏢鑣槓盒靖膜齡俞豹獵噪孚封札筒託衍鴿剪撰稿煉廠禊練繕葺俯瞰撐衝俲俳俴俵俶俷俺俻俾倀倂倅儲卒惶敷猝逃頡蓄崇隱倌倏忽刺蠟燭噍嚼坍扁抽斃蔥楣灌灶糞背藪賣賠閉霉騰倓倔倖倘倜儻倝借箸挹澆閱倡狂倢倣値倥傯倨��倩匡嗣沖柝珍倬倭寇猩倮倶倷倹勤讚偁偃充偽吏嗓寐惺扮拱芫茜藉虢鈔偈偉晶偌宕距析濾殿疼癱註頗偓偕鴨歇滯偝偟偢忘怡旺偨偩偪偫偭偯偰偱偲偵緝蹄偷減惰漏窺竊偸偺迹傀儡傅傈僳傌籬傎奎琳迪叟芭傒傔傕傖悉荒傜傞傢傣芽逼傭婢傮睨寄檄誦謠頌傴擔辜弓慘蒿悼疤傺傻屄臆巢洩篋羨蓋軋頹傿儸僄僇僉僊働僎僑僔僖僚僝僞僣僤僥僦猴僨僩僬僭僮僯僰僱僵殖籤靜僾僿征隴儁儂儃儇儈朴薄儊儋儌儍儐儓儔儕儗儘儜儞儤儦儩汰哉寡渥裕酷儭儱罐儳儵儹儺儼儽兀臬臲鷲允勛勳宙宵帥憝彞諧嫂鬩暢沛溢盈飢赫兇悍狠猛頑愚妣斬秦遣鞭耀敏榮槃澤爆碟磁禿纜輝霽鹵朵婁孜烽醬勃汀箕裘鉗耶懞蕾徹兌軟遭黜兎児韻媳爸兕觥兗兙兛兜售鍪肚兝兞兟兡兢兣樽殮涅睡稟籍贅泌啡肽奸幕涵澇熵疚眷稃襯訌赴煥椒殲植跏沒試誤猜棲窗肋袖頰兪卦撇鬍岐廓轎疸楓茴瓏廁秩募勺噸寓斤曆畝迫筷釐最淫螺韜兮寬匪篩襄贏軛複兲詐刃堰戎痞蟻餉它冀鑄冂冃円冇冉冊嫁厲礪竭醮冏牧冑冓冔冕冖冗冘冞冢窄抑誣冥冫烘菇蟄冷凝坨橇淇淋炭餅磚磧窖醋雕雹霜冱冶爐艷嘲峻灘淡漠煖颼飲冼冽凃凄愴梗凅凇凈凊凋敝濛凔凜遵汞脢凞几凢処凰凱凵凶焰凸摺刷紋預喪嘍奔巡榜殯芙蓉租籠輯鞘萃凼鋸鑊刁蠻刂娩崩批拆攤掰櫱驟歧顆秒袂贓勿囑忌磋琢膚刈羽刎訟戮舂槳艇刓刖霹靂刜創犢刡恙墅幟筵緻刦刧刨昏默攸尿慾薰潤薰圭刪刮痧鏟刱刲刳刴刵踏磅戳柏槐繡芹莧蝟舟銘鵠鶩刼剁剃辮剄剉履鉛剋剌姻咽哨廊掠桅沿召瞻翅趙卜渺茫郭剒剔剕瀝剚愎毅訥纔剜剝啄採剞剟剡剣剤綵剮腎駛黏剰袍剴紊剷剸剺剽剿劁劂劄劈啪柴扳啦劉奭姥夼昫涓熙禪禹錫翔雁鶚劊劌弩柄蜻蛉劒劓劖劘劙瀾簣賞磯釜晉甜薪逐劦熔紂虐赤囚劬劭労劵効劻劼劾峭艮勅勇勵勍勐臘脖龐漫飼盪粥輒勖勗勘驕餒碌泮雇捐竹騎殊阱勣樸懇謹勦勧勩勯勰勱勲勷勸懲慰誡諫勹芡踐闌匁庇拯粟紮袱裹餃匆遽匈匉匊匋匍匐莖匏匕妝痰膿蛹齋苑烤蹈塘羌熊閥螳螂疆碚竿緯荷茵邙魏匚匜匝匟扶稷匣匭攏匸匹耦匽匾匿卂叮瘡禧軫堤棚迢鈞鍊卄卆遐卉瓷盲瓶噹胱腱裸卋卌卍卐怯污賤鄙齷齪陋卓溪唐梯漁陳棗泥漳潯澗梨芬譙贍轅迦鄭単驢弈洽鰲卛占筮卝卞卟吩啉屎翠厄卣卨卪卬卮榫襖璽綬鈕蚤懼殆篤聳卲帘帙繞卹卼卽厂厎厓厔厖厗奚厘厙厜厝諒厠厤厥厪膩孢厮厰厳厴厹厺粕垢蕪菁厼厾叁悟茸薯叄吵笄悌哺譏坫壟弧芯杠潛嬰芻袁詰貪諜煽饋駁収岳締災賄騙叚叡吻攔蘑蜜訣燧玩硯箏椎藺銅逗驪另覓叨嘮謁杵姓喊嚷囂咚嚀塑尋惱憎擦祇泣滲蝠叱吒咄咤喝籀黛舵舷叵叶鐸懿昭穰苴遼叻叼吁塹嫖賭瞧爬衆抒吅吆夥巹橡滌抱縱摩郡唁墜扇籃膀襪頸吋愾諮酬哭妓媛暗錶韁邇妃羿絮蕃渾拐葵暮隅吔吖啶嗪戚吜嗇噬嚥吟哦詠吠吧唧嗒咐吪雋咀徵燐苞茹鈣哧吮吰吱嘎吲哚吳棟嬌窟孟簫忠晗淞闔閭趼宇吶睛噓拂捧疵熄竽笛糠吼吽呀呂韋矇呃呆笨呇貢呉罄呋喃呎呏呔呠呡癡呣呤呦呧瑛眩扒晬淑姬瑜璇鵑呪呫嗶嚅囁呬呯呰呱呲咧噌鈍呴呶呷呸呺呻哱咻嘯嚕籲坎坷邏呿咁咂咆哮咇咈咋蟹煦珅藹咍咑咒詛咔噠嚓咾噥哩喱咗咠咡咢咣咥咦咨嗟詢咩咪咫嚙齧咭咮咱咲咳嗆嗽咴咷咸咹咺咼喉咿婉慟憫賦矜綠茗藍哂搶瞞哆嗦囉噻啾濱彗哋哌哎唷喲哏哐哞哢哤哪裏哫啼喘哰哲萎蚌哳哶哽哿唄唅唆唈唉唎唏嘩堯棣殤璜睿肅唔睇唕唚唞唣喳唪唬唰喏唲唳唵嘛唶唸唹唻唼唾唿啁啃鸚鵡啅埠棧榷祺舖鞅飆啊啍啎啐啓啕啖啗啜啞祈啢啣啤啥啫啱啲啵啺饑啽噶崑沁喁喂喆裙喈嚨喋喌喎喑喒喓喔粗喙幛慶滋鵲喟喣喤喥喦喧騷喨喩梆喫葡萄喭駝挑嚇碰樅瓣純皰藻趟鉻喵営喹喺喼喿嗀嗃嗄嗅嗈嗉嗊嗍嗐嗑嗔詬嗕嗖嗙嗛嗜痂癖嗝嗡嗤嗥嗨嗩嗬嗯嗰嗲嗵嘰嗷嗹嗾嗿嘀嘁嘂嘅惋嘈峪禾蔭嘊嘌嘏嘐嘒嘓嘖嘚嘜嘞嘟囔嘣嘥嘦嘧嘬嘭這謔嚴敞饞鬆嘵嘶嘷嘸蝦嘹嘻嘽嘿噀噂噅噇噉噎噏噔噗噘噙噚噝噞噢噤蟬皿噩噫噭噯噱噲噳嚏涌灑欲巫霏噷噼嚃嚄嚆抖嚌嚐嚔囌嚚嚜嚞嚟嚦嚬嚭嚮嚯嚲嚳飭按竣苛嚵嚶囀囅囈膪謙囍囒囓囗囘蕭酚飄濺諦囝溯眸紇鑾鶻囟殉囡団囤囥囧囨囪囫圇囬囮囯囲図囶囷囸囹圄圉擬囻囿圀圂圃圊粹蠹赦圌墾圏滾鯡鑿枘圕圛圜圞坯埂壤骸炕祠窯豚紳魠鯪鱉圧握圩圪垯圬圮圯炸岬幔毯祇窨菩溉圳圴圻圾坂坆沾坋坌舛壈昆墊墩椅坒坓坩堝坭坰坱坳坴坵坻坼楊掙涎簾垃垈垌垍垓垔垕垗垚垛垝垣垞垟垤垧垮垵垺垾垿埀畔埄埆埇埈埌殃隍埏埒埕埗埜埡埤埦埧埭埯埰埲埳埴埵埶紼埸培怖樁礎輔埼埽堀訶姪廡堃堄摧磐貞韌砌堈堉堊堋堌堍堎堖堙堞堠礁堧堨輿堭堮蜓摘堲堳堽堿塁塄塈煤塋棵塍塏塒塓綢���鴉沽虱塙塚塝繆塡塢塤塥塩塬塱塲蟎塼塽塾塿墀墁墈墉墐夯増毀墝墠墦漬缽墫墬墮墰墺墻櫥壅壆壊壌壎壒榨蒜壔壕壖壙壚壜壝壠壡壬壭壱売壴壹壻壼寢壿夂夅夆変夊夌漱邑夓腕泄甥禦骼夗夘夙袞瑙妊娠醣梟珊鶯鷺戧幻魘夤蹀祕擂鶇姚宛閨嶼庾撻拇賛蛤裨菠氅漓撈湄蚊霆鯊箐篆篷荊肆舅荔鮃巷慚骰辟邱鎔鐮阪漂燴鯢鰈鱷鴇臚鵬妒峨譚枰晏璣癸祝秤竺牡籟恢罡螻蠍賜絨御梭夬夭砣榆怙枕夶夾餡奄崛葩譎奈賀祀贈奌奐奓奕訢詝奘奜奠奡奣陶奨奩魁奫奬奰媧孩貶隸酥宄狡猾她奼嫣妁氈荼皋膻蠅嬪妄妍嫉媚嬈妗趣妚妞妤礙妬婭妯娌妲妳妵妺姁姅姉姍姒姘姙姜姝姞姣姤姧姫姮娥姱姸姺姽婀娀誘懾脅娉婷娑娓娟娣娭娯娵娶娸娼婊婐婕婞婤婥谿孺婧婪婬婹婺婼婽媁媄媊媕媞媟媠媢媬媮媯媲媵媸媺媻媼眯媿嫄嫈嫋嫏嫕嫗嫘嫚嫜嫠嫡嫦嫩嫪毐嫫嫬嫰嫵嫺嫻嫽嫿嬀嬃嬅嬉耍嬋痴豔嬔嬖嬗嬙嬝嬡嬢嬤嬦嬬嬭幼嬲嬴嬸嬹嬾嬿孀孃孅孌孏曰癲屏孑孓雀孖斟簍謎摺孛矻鳩崮軻祜鸞孥邈毓棠臏孬孭孰孱孳孵泛罔銜孻孿宀宁宂拙株薇掣撫琪瓿榴謐彌宊濂祁瑕宍宏碁宓邸讞実潢町宥宧宨宬徵崎駿掖闕臊煮禽蠶宸豫寀寁寥寃簷庶寎暄磣寔寖寘寙寛寠苫寤肘洱濫蒗陝覈寪弘綽螽寳擅疙瘩晷対檐専尃尅贖絀繚疇釁尌峙醌襟痲碧屁昊槌淘恵瀑牝畑莓缸羚覷蔻髒躁尒尓銳尗尙尜尟尢尥尨尪尬尭尰擒尲尶尷尸尹潽蠖蛾尻釦梢蚴鰭脬蹲屇屌蚵屐屓挪屖屘屙屛屝屢屣巒嶂巖舄屧屨屩屪屭屮戍駐鉀崖嵛巔旮旯楂欖櫸芋茱萸靛麓屴屹屺屼岀岊岌岍阜岑彭鞏岒岝岢嵐岣岧岨岫岱岵岷峁峇峋峒峓峞峠嵋峩峯峱峴峹峿崀崁崆禎崋崌崍嶇崐崒崔嵬巍螢顥崚崞崟崠崢巆崤崦崧殂崬崱崳崴崶崿嵂嵇嵊泗嵌嵎嵒嵓嵗嵙嵞嵡嵩嵫嵯嵴嵼嵾嶁嶃嶄晴嶋嶌嶒嶓嶔嶗嶙嶝嶞嶠嶡嶢嶧嶨嶭嶮嶰嶲嶴嶸巂巃巇巉巋巌巓巘巛滇芎巟巠弋迴巣巤炊擘蜥蟒蠱覡巰蜀彥淖杏茂甫楞巻巽幗巿帛斐鯽蕊帑帔帗帚琉汶帟帡帣帨帬帯帰帷帹暆幃幄幇幋幌幏幘幙幚幞幠幡幢幦幨幩幪幬幭幯幰遙蹉跎餘庚鑑幵幷稚邃庀庁広庄庈庉笠庋跋庖犧庠庤庥鯨庬庱庳庴庵馨衢庹庿廃廄廆廋廌廎廏廐廑廒廕廖廛廝搏鑼廞弛袤廥廧廨廩廱綿踵髓廸廹甌鄴廻廼廾廿躔弁皺弇弌弍弎弐弒弔詭憾薦弝弢弣弤弨弭弮弰弳霖繇燾斌旭溥騫弶弸弼弾彀彄彆纍糾彊彔彖彘彟彠陌彤貽彧繪虹彪炳彫蔚鷗彰癉彲彳彴彷彷徉徨彸彽踩斂旆徂徇徊渭畬鉉裼従筌徘徙徜徠膳甦萌漸徬徭醺徯徳徴潘徻徼忀瘁胖燎怦悸顫扉犀澎湃砰恍惚絞隘忉憚挨餓忐忑忒忖応忝忞耿忡忪忭忮忱忸怩忻悠懣怏遏怔怗怚怛怞懟黍訝怫怭懦怱怲怳怵惕怸怹恁恂恇恉恌恏恒恓恔恘恚恛恝恞恟恠恣恧眄恪恫恬澹恰恿悀悁悃悄悆悊悐悒晦悚悛悜悝悤您悩悪悮悰悱悽惻悳悴悵惘悶悻悾惄愫鍾蒐惆惇惌惎惏惓惔惙惛耄惝瘧濁惥惦惪惲惴惷惸拈愀愃愆愈愊愍愐愑愒愓愔愕愙氓蠢騃昵愜赧愨愬愮愯愷愼慁慂慅慆慇靄慉慊慍慝慥慪慫慬慱慳慴慵慷慼焚憀灼鬱憃憊憋憍眺捏軾憒憔憖憙憧憬憨憪憭憮憯憷憸憹憺懃懅懆邀懊懋懌懍懐懞懠懤懥懨懫懮懰懱毖懵遁樑雍懺懽戁戄戇戉戔戕戛戝戞戠戡戢戣戤戥戦戩戭戯轟戱披菊牖戸戹戺戻戼戽鍬扂楔扃扆扈扊杖牽絹銬鐲賚扐摟攪烊盹瞌跟躉鑔靶鼾払扗玫腮扛扞扠扡扢盔押扤扦扱罾揄綏鞍郤窾扻扼扽抃抆抈抉抌抏瞎抔繯縊擻抜抝択抨摔歉躥牾抶抻搐泵菸拃拄拊髀拋拌脯拎拏拑擢秧沓曳攣迂拚拝拠拡拫拭拮踢拴拶拷攢拽掇芥橐簪摹疔挈瓢驥捺蹻挌挍挎挐揀挓挖掘浚挙揍聵挲挶挾挿捂捃捄捅捆捉捋胳膊揎捌捍捎軀蛛捗捘捙捜捥捩捫捭据捱捻捼捽掀掂掄臀膘掊掎掏掐笙掔掗掞棉芍掤搪闡掫掮掯揉掱掲掽掾揃揅揆搓揌諢揕揗揘揜揝揞揠揥揩揪揫櫫遒麈揰揲揵揶揸揹揺搆搉搊搋搌搎搔搕撼櫓搗搘搠搡搢搣搤搥搦搧搨搬楦褳訕赸搯搰搲搳搴搵搷搽搾搿摀摁摂摃摎摑摒摓跤摙摛摜摞摠摦睺羯摭摮摯摰摲摳摴摶摷摻摽撂撃撅稻撊撋撏鐧潑撕撙撚撝撟撢撣撦撧撩撬撱朔撳蚍蜉撾撿擀擄闖擉缶觚擐擕擖擗擡擣擤澡腚擧擨擩擫擭擯擰擷擸擼擽擿攃攄攆攉攥攐攓攖攙攛每攩攫轡澄攮攰攲攴軼攷砭訐攽碘敁敃敇敉敍敎筏敔敕敖閏誨敜煌敧敪敱敹敺敻敿斁衽斄牒縐謅斉斎斕鶉讕駮鱧斒筲斛斝斞斠斡斢斨斫斮晾沂潟穎絳邵斲斸釳於琅斾斿旀旂旃旄渦旌旎旐旒旓旖旛旝旟旡旣浴旰獺魃旴旹旻旼旽昀昃昄昇昉晰躲澈熹皎皓礬昑昕昜昝昞昡昤暉筍昦昨昰昱昳昴昶昺昻晁蹇隧蔬髦晄晅晒晛晜晞晟晡晢晤晥曦晩萘瑩顗晿暁暋暌暍暐暔暕煅暘暝暠暡曚暦暨暪朦朧暱暲殄馮暵暸暹暻暾曀曄曇曈曌曏曐曖曘曙曛曡曨曩駱曱甴肱曷牘禺錕曽滄耽朁朅朆杪栓誇竟粘絛朊膺朏朐朓朕朘朙瞄覲溘饔飧朠朢朣柵椆澱蝨朩朮朰朱炆璋鈺熾鹮朳槿朶朾朿杅杇杌隉欣釗湛漼楷瀍煜玟纓翱肈舜贄适逵杓杕杗杙荀蘅杝杞脩珓筊杰榔狍閦顰緬莞杲杳眇杴杶杸杻杼枋枌枒枓衾葄翹紓逋枙狸椏枟槁枲枳枴枵枷枸櫞枹枻柁柂柃柅柈柊柎某柑橘柒柘柙柚柜柞櫟柟柢柣柤柩柬柮柰柲橙柶柷柸柺査柿栃栄栒栔栘栝栟栢栩栫栭栱栲栳栴檀栵栻桀驁桁鎂桄桉桋桎梏椹葚桓桔桕桜桟桫欏桭桮桯桲桴桷桹湘溟梃梊梍梐潼梔梘梜梠梡梣梧梩梱梲梳梴梵梹棁棃櫻棐棑棕櫚簑繃蓑棖棘棜棨棩棪棫棬棯棰棱棳棸棹槨棼椀椄苕椈椊椋椌椐椑椓椗検椤椪椰椳椴椵椷椸椽椿楀楄楅篪楋楍楎楗楘楙楛楝楟楠楢楥楨楩楪楫楬楮楯楰楳楸楹楻楽榀榃榊榎槺榕榖榘榛狉莽榜笞榠榡榤榥榦榧榪榭榰榱槤霰榼榾榿槊閂槎槑槔槖様槜槢槥槧槪槭槮槱槲槻槼槾樆樊樏樑樕樗樘樛樟樠樧樨権樲樴樵猢猻樺樻罍樾樿橁橄橆橈笥龠橕橚橛輛橢橤橧豎膈跨橾橿檁檃檇檉檍檎檑檖檗檜檟檠檣檨檫檬檮檳檴檵檸櫂櫆櫌櫛櫜櫝櫡櫧櫨櫪櫬櫳櫹櫺茄櫽欀欂欃欐欑欒欙欞溴欨欬欱欵欶欷歔欸欹欻欼欿歁歃歆艎歈歊蒔蝶歓歕歘歙歛歜歟歠蹦詮鑲蹣跚陞陟歩歮歯歰歳歴璞歺瞑歾歿殀殈殍殑殗殜殙殛殞殢殣殥殪殫殭殰殳荃殷殸殹蛟殻殽謗毆毈毉餵毎毑蕈毗毘毚茛鄧毧毬毳毷毹毽毾毿氂氄氆靴氉氊氌氍氐聊氕氖気氘氙氚氛氜氝氡洶焊痙氤氳氥氦鋁鋅氪烴氬銨痤汪滸漉痘盂碾菖蒲蕹蛭螅氵氷氹氺氽燙氾氿渚汆汊汋汍汎汏汐汔汕褟汙汚汜蘺沼穢衊汧汨汩汭汲汳汴隄汾沄沅沆瀣沇沈葆浸淪湎溺痼痾沌沍沏沐沔沕沘浜畹礫沚沢沬沭沮沰沱灢沴沷籽沺烹濡洄泂肛泅泆湧肓泐泑泒泓泔泖泙泚泜泝泠漩饃濤粼濘蘚鰍泩泫泭泯銖泱泲洇洊涇琵琶荽薊箔洌洎洏洑潄濯洙洚洟洢洣洧洨洩痢滔洫洮洳洴洵洸洹洺洼洿淌蜚浄浉浙贛渫浠浡浤浥淼瀚浬浭翩萍浯浰蜃淀苔蛞蝓蜇螵蛸煲鯉浹浼浽溦涂涊涐涑涒涔滂涖涘涙涪涫涬涮涴涶涷涿淄淅淆淊淒黯淓淙漣淜淝淟淠淢淤淥淦淩猥藿褻淬淮淯淰淳詣淶紡淸淹燉癯綺渇済渉渋渓渕渙渟渢滓渤澥渧渨渮渰渲渶渼湅湉湋湍湑湓湔黔湜湝湞湟湢湣湩湫湮麟湱湲湴湼満溈溍溎溏溛舐漭溠溤溧馴溮溱溲溳溵溷溻溼溽溾滁滃滉滊滎滏稽滕滘滙滝滫滮羼耷滷滹滻煎漈漊漎繹漕漖漘漙漚漜漪漾漥漦漯漰漵漶漷濞潀潁潎潏潕潗潚潝潞潠潦祉瘍潲潵潷潸潺潾潿澁澂澃澉澌澍澐澒澔澙澠澣澦澧澨澫澬澮澰澴澶澼熏郁濆濇濈濉濊貊濔疣濜濠濩觴濬濮盥濰濲濼瀁瀅瀆瀋瀌瀏瀒瀔瀕瀘瀛瀟瀠瀡瀦瀧瀨瀬瀰瀲瀳瀵瀹瀺瀼灃灄灉灋灒灕灖灝灞灠灤灥灨灩灪蜴灮燼獴灴灸灺炁炅魷炗炘炙炤炫疽烙釺炯炰炱炲炴炷燬炻烀烋瘴鯧烓烔焙烜烝烳飪烺焃焄耆焌焐焓焗焜焞焠焢焮焯焱焼煁煃煆煇煊熠煍熬煐煒煕煗燻礆霾煚煝煟煠煢矸煨瑣煬萁煳煺煻熀熅熇熉羆熒穹熗熘熛熜稔諳爍熤熨熯熰眶螞熲熳熸熿燀燁燂燄盞燊燋燏燔隼燖燜燠燡燦燨燮燹燻燽燿爇爊爓爚爝爟爨蟾爯爰爲爻爿爿牀牁牂牄牋牎牏牓牕釉牚腩蒡虻牠雖蠣牣牤牮牯牲牳牴牷牸牼絆牿靬犂犄犆犇犉犍犎犒犖犗犛犟犠犨犩犪犮犰狳犴犵犺狁甩狃狆狎狒獾狘狙黠狨狩狫狴狷狺狻豕狽蜘猁猇猈猊猋猓猖獗猗猘猙獰獁猞猟獕猭猱猲猳猷猸猹猺玃獀獃獉獍獏獐獒獘獙獚獜獝獞獠獢獣獧鼇蹊獪獫獬豸獮獯鬻獳獷獼玀玁菟玅玆玈珉糝禛郅玍玎玓瓅玔玕玖玗玘玞玠玡玢玤玥玦玨瑰玭玳瑁玶玷玹玼珂珇珈瑚珌饈饌珔珖珙珛珞珡珣珥珧珩珪珮珶珷珺珽琀琁隕琊琇琖琚琠琤琦琨琫琬琭琮琯琰琱琲瑯琹琺琿瑀瑂瑄瑉瑋瑑瑔瑗瑢瑭瑱瑲瑳瑽瑾瑿璀璨璁璅璆璈璉璊璐璘璚璝璟璠璡璥璦璩璪璫璯璲璵璸璺璿瓀瓔瓖瓘瓚瓛臍瓞瓠瓤瓧瓩瓮瓰瓱瓴瓸瓻瓼甀甁甃甄甇甋甍甎甏甑甒甓甔甕甖甗飴蔗甙詫鉅粱盎銹糰甡褥産甪甬甭甮甯鎧甹甽甾甿畀畁畇畈畊畋畎畓畚畛畟鄂畤畦畧荻畯畳畵畷畸畽畾疃疉疋疍疎簞疐疒疕疘疝疢疥疧疳疶疿痁痄痊痌痍痏痐痒痔痗瘢痚痠痡痣痦痩痭痯痱痳痵痻痿瘀瘂瘃瘈瘉瘊瘌瘏瘐瘓瘕瘖瘙瘚瘛瘲瘜瘝瘞瘠瘥瘨瘭瘮瘯瘰癧瘳癘瘵瘸瘺瘻瘼癃癆癇癈癎癐癔癙癜癠癤癥癩蟆癪癭癰発踔紺蔫酵皙砬砒翎翳蘞鎢鑞皚鵯駒鱀粵褶皀皁莢皃鎛皈皌皐皒硃皕皖皘皜皝皞皤皦皨皪皫皭糙綻皴皸皻皽盅盋盌盍盚盝踞盦盩鞦韆盬盭眦睜瞤盯盱眙裰盵盻睞眂眅眈眊県眑眕眚眛眞眢眣眭眳眴眵眹瞓眽郛睃睅睆睊睍睎睏睒睖睙睟睠睢睥睪睪睯睽睾瞇瞈瞋瞍逛瞏瞕瞖瞘瞜瞟瞠瞢瞫瞭瞳瞵瞷瞹瞽闍瞿矓矉矍鑠矔矗矙矚矞矟矠矣矧矬矯矰矱硪碇磙��舫阡、矼矽礓砃砅砆砉砍砑砕砝砟砠砢砦砧砩砫砮砳艏砵砹砼硇硌硍硎硏硐硒硜硤硨磲茚鋇硭硻硾碃碉碏碣碓碔碞碡碪碫碬碭碯碲碸碻礡磈磉磎磑磔磕磖磛磟磠磡磤磥蹭磪磬磴磵磹磻磽礀礄礅礌礐礚礜礞礤礧礮礱礲礵礽礿祂祄祅祆禳祊祍祏祓祔祕祗祘祛祧祫祲祻祼餌臠錮禂禇禋禑禔禕隋禖禘禚禜禝禠禡禢禤禥禨禫禰禴禸稈秈秊闈颯秌秏秕笈蘵賃秠秣秪秫秬秭秷秸稊稌稍稑稗稙稛稞稬稭稲稹稼顙稾穂穄穇穈穉穋穌貯穏穜穟穠穡穣穤穧穨穭穮穵穸窿闃窀窂窅窆窈窕窊窋窌窒窓窔窞窣窬黷蹙窰窳窴窵窶窸窻竁竃竈竑竜竝竦竪篦篾笆鮫竾笉笊笎笏笐靨笓笤籙笪笫笭笮笰笱笲笳笵笸笻筀筅筇筈筎筑筘筠筤筥筦筧筩筭筯筰筱筳筴讌筸箂箇箊箎箑箒箘箙箛箜篌箝箠箬鏃箯箴箾篁篔簹篘篙篚篛篜篝篟篠篡篢篥篧篨篭篰篲篳篴篶篹篼簀簁簃簆簉簋簌簏簜簟簠簥簦簨簬簰簸簻籊籐籒籓籔籖籚籛籜籣籥籧籩籪籫籯芾麴籵籸籹籼粁粃粋粑粔糲粛粞粢粧粨粲粳粺粻粽闢粿糅糆糈糌糍糒糔萼糗蛆蹋糢糨糬糭糯糱糴糶糸糺紃蹼鰹黴紆紈絝紉閩襻紑紕紘錠鳶鷂紝紞紟紥紩紬紱紲紵紽紾紿絁絃絅経絍絎絏縭褵絓絖絘絜絢絣螯絪絫聒絰絵絶絺絻絿綀綃綅綆綈綉綌綍綎綑綖綘継続緞綣綦綪綫綮綯綰罟蝽綷縩綹綾緁緄緅緆緇緋緌緎総緑緔緖緗緘緙緜緡緤緥緦纂緪緰緱緲緶緹縁縃縄縈縉縋縏縑縕縗縚縝縞縟縠縡縢縦縧縯縰騁縲縳縴縵縶縹縻衙縿繄繅繈繊繋繐繒繖繘繙繠繢繣繨繮繰繸繻繾纁纆纇纈纉纊纑纕纘纙纚纛缾罃罆罈罋罌罎罏罖罘罛罝罠罣罥罦罨罫罭鍰罳罶罹罻罽罿羂羃羇羋蕉51鴕羑羖羗羜羝羢羣羥羧羭羮羰羱羵羶羸藜鮐翀翃翄翊翌翏翕翛翟翡翣翥翦躚翪翫翬翮翯翺翽翾翿闆饕鴰鍁耋耇耎耏耑耒耜耔耞耡耤耨耩耪耬耰鬢耵聹聃聆聎聝聡聦聱聴聶聼閾聿肄肏肐肕腋肙肜肟肧胛肫肬肭肰肴肵肸肼胊胍胏胑胔胗胙胝胠銓胤胦胩胬胭胯胰胲胴胹胻胼胾脇脘脝脞脡脣脤脥脧脰脲脳腆腊腌臢腍腒腓腖腜腠腡腥腧腬腯踝蹬鐐腴腶蠕誹膂膃膆膇膋膔膕膗膙膟黐膣膦膫膰膴膵膷膾臃臄臇臈臌臐臑臓臕臖臙臛臝臞臧蓐詡臽臾臿舀舁鰟鮍舋舎舔舗舘舝舠舡舢舨舭舲舳舴舸舺艁艄艅艉艋艑艕艖艗艘艚艜艟艣艤艨艩艫艬艭荏艴艶艸艹艻艿芃芄芊萰陂藭芏芔芘芚蕙芟芣芤茉芧芨芩芪芮芰鰱芴芷芸蕘豢芼芿苄苒苘苙苜蓿苠苡苣蕒苤苧苪鎊苶苹苺苻苾茀茁范蠡萣茆茇茈茌茍茖茞茠茢茥茦菰茭茯茳藨茷藘茼荁荄荅荇荈菅蜢鴞荍荑荘荳荵荸薺莆莒莔莕莘莙莚莛莜莝莦莨菪莩莪莭莰莿菀菆菉菎菏菐菑菓菔菕菘菝菡菢菣菥蓂菧菫轂鎣菶菷菹醢菺菻菼菾萅萆萇萋萏萐萑萜萩萱萴萵萹萻葇葍葎葑葒葖葙葠葥葦葧葭葯葳葴葶葸葹葽蒄蒎蒓蘢薹蒞蒟蒻蒢蒦蒨蒭藁蒯蒱鉾蒴蒹蒺蒽蓀蓁蓆蓇蓊蓌蓍蓏蓓蓖蓧蓪蓫蓽跣藕蓯蓰蓱蓴蓷蓺蓼蔀蔂蔃蔆蔇蔉蔊蔋蔌蔎蔕蔘蔙蔞蔟鍔蔣雯蔦蔯蔳蔴蔵蔸蔾蕁蕆蕋蕍蕎蕐蕑蕓蕕蕖蕗蕝蕞蕠蕡蕢蕣蕤蕨蕳蕷蕸蕺蕻薀薁薃薅薆薈薉薌薏薐薔薖薘薙諤釵薜薠薢薤薧薨薫薬薳薶薷薸薽薾薿藄藇藋藎藐藙藚藟藦藳藴藶藷藾蘀蘁蘄蘋蘗蘘蘝蘤蘧蘩蘸蘼虀虆虍蟠虒虓虖虡虣虥虩虯虰蛵虵虷鱒虺虼蚆蚈蚋蚓蚔蚖蚘蚜蚡蚣蚧蚨蚩蚪蚯蚰蜒蚱蚳蚶蚹蚺蚻蚿蛀蛁蛄蛅蝮蛌蛍蛐蟮蛑蛓蛔蛘蛚蛜蛡蛣蜊蛩蛺蛻螫蜅蜆蜈蝣蜋蜍蜎蜑蠊蜛餞蜞蜣蜨蜩蜮蜱蜷蜺蜾蜿蝀蝃蝋蝌蝍蝎蝏蝗蝘蝙蝝鱝蝡蝤蝥蝯蝰蝱蝲蝴蝻螃蠏螄螉螋螒螓螗螘螙螚蟥螟螣螥螬螭螮螾螿蟀蟅蟈蟊蟋蟑蟓蟛蟜蟟蟢蟣蟨蟪蟭蟯蟳蟶蟷蟺蟿蠁蠂蠃蠆蠋蠐蠓蠔蠗蠙蠚蠛蠜蠧蠨蠩蠭蠮蠰蠲蠵蠸蠼蠽衁衂衄衇衈衉衋衎衒衕衖衚衞裳鈎衭衲衵衹衺衿袈裟袗袚袟袢袪袮袲袴袷袺袼褙袽裀裉裊裋裌裍裎裒裛裯裱裲裴裾褀褂褉褊褌褎褐褒褓褔褕褘褚褡褢褦褧褪褫褭褯褰褱襠褸褽褾襁襃襆襇襉襋襌襏襚襛襜襝襞襡襢襤襦襫襬襭襮襴襶襼襽襾覂覃覅覇覉覊覌覗覘覚覜覥覦覧覩覬覯覰観覿觔觕觖觜觽觝觡酲觩觫觭觱觳觶觷觼觾觿言賅訃訇訏訑訒詁託訧訬訳訹証訾詀詅詆譭詈詊詎詑詒詖詗詘詧詨詵詶詸詹詻詼詿誂誃誄鋤誆誋誑誒誖誙誚誥誧説読誯誶誾諂諄諆諌諍諏諑諕諗諛諝諞諟諠諡諴諵諶諼謄謆謇謌謍謏謑謖謚謡謦謪謫謳謷謼謾譁譅譆譈譊譌譒譔譖鑫譞譟譩譫譬譱譲譴譸譹譾讅讆讋讌讎讐讒讖讙讜讟谽豁豉豇豈豊豋豌豏豔豞豖豗豜豝豣豦豨豭豱豳豵豶豷豺豻貅貆貍貎貔貘貙貜貤饜貰餸貺賁賂賏賒賕賙賝賡賧賨賫鬭賮賵賸賺賻賾贇贉贐贔贕贗赬赭赱赳迄趁趂趄趐趑趒趔趡趦趫趮趯趲趴趵趷趹趺趿跁跂跅跆躓蹌跐跕跖跗跙跛跦跧跩跫跬跮跱跲跴跺跼跽踅踆踈踉踊踒���踘踜踟躇躕踠踡踣踤踥踦踧蹺踫踮踰踱踴踶踹踺踼踽躞蹁蹂躪蹎蹐蹓蹔蹕蹚蹜蹝蹟蹠蹡蹢躂蹧蹩蹪蹯鞠蹽躃躄躅躊躋躐躑躒躘躙躛躝躠躡躦躧躩躭躰躳躶軃軆輥軏軔軘軜軝齶転軥軨軭軱軲轆軷軹軺軽軿輀輂輦輅輇輈輓輗輙輜輞輠輤輬輭輮輳輴輵輶輹輼輾轀轇轏轑轒轔轕轖轗轘轙轝轞轢轤辠辢辤辵辶辺込辿迅迋迍麿迓迣迤邐迥迨迮迸迺迻迿逄逅逌逍逑逓逕逖逡逭逯逴逶逹遄遅遉遘遛遝遢遨遫遯遰遴遶遹遻邂邅邉邋邎邕邗邘邛邠邢邧邨邯鄲邰邲邳邴邶邷邽邾邿郃郄郇郈郔郕郗郙郚郜郝郞郟郠郢郪郫郯郰郲郳郴郷郹郾郿鄀鄄鄆鄇鄈鄋鄍鄎鄏鄐鄑鄒鄔鄕鄖鄗鄘鄚鄜鄞鄠鄢鄣鄤鄦鄩鄫鄬鄮鄯鄱鄶鄷鄹鄺鄻鄾鄿酃酅酆酇酈酊酋酎酏酐酣酔酕醄酖酗酞酡酢酤酩酴酹酺醁醅醆醊醍醐醑醓醖醝醞醡醤醨醪醭醯醰醱醲醴醵醸醹醼醽醾釂釃釅釆釈鱸鎦閶釓釔釕鈀釙鼢鼴釤釧釪釬釭釱釷釸釹鈁鈃鈄鈆鈇鈈鈊鈌鈐鈑鈒鈤鈥鈧鈬鈮鈰鈳鐺鈸鈹鈽鈿鉄鉆鉈鉋鉌鉍鉏鉑鉕鉚鉢鉥鉦鉨鉬鉭鉱鉲鉶鉸鉺鉼鉿銍銎銑銕鏤銚銛銠銣銤銥銦銧銩銪銫銭銰銲銶銻銼銾鋂鋃鋆鋈鋊鋌鋍鋏鋐鋑鋕鋘鋙鋝鋟鋦鋨鋩鋭鋮鋯鋰鋱鋳鋹鋺鋻鏰鐱錀錁錆錇錈錍錏錒錔錙錚錛錞錟錡錤錩錬録錸錼鍀鍆鍇鍉鍍鍏鍐鍘鍚鍛鍠鍤鍥鍩鍫鍭鍱鍴鍶鍹鍺鍼鍾鎄鎇鎉鎋鎌鎍鎏鎒鎓鎗鎘鎚鎞鎡鎤鎩鎪鎭鎯鎰鎳鎴鎵鎸鎹鎿鏇鏊鏌鏐鏑鏖鏗鏘鏚鏜鏝鏞鏠鏦鏨鏷鏸鏹鏻鏽鏾鐃鐄鐇鐏鐒鐓鐔鐗馗鐙鐝鐠鐡鐦鐨鐩鐫鐬鐱鐳鐶鐻鐽鐿鑀鑅鑌鑐鑕鑚鑛鑢鑤鑥鑪鑭鑯鑱鑴鑵鑷钁钃镻閆閈閌閎閒閔閗閟閡関閤閤閧閬閲閹閺閻閼閽閿闇闉闋闐闑闒闓闘闚闞闟闠闤闥阞阢阤阨阬阯阹阼阽陁陑陔陛陜陡陥陬騭陴険陼陾隂隃隈隒隗隞隠隣隤隩隮隰顴隳隷隹雂雈雉雊雎雑雒雗雘雚雝雟雩雰雱驛霂霅霈霊霑霒霓霙霝霢霣霤霨霩霪霫霮靁靆靉靑靚靣靦靪靮靰靳靷靸靺靼靿鞀鞃鞄鞌鞗鞙鞚鞝鞞鞡鞣鞨鞫鞬鞮鞶鞹鞾韃韅韉馱韍韎韔韖韘韝韞韡韣韭韮韱韹韺頀颳頄頇頊頍頎頏頒頖頞頠頫頬顱頯頲頴頼顇顋顑顒顓顔顕顚顜顢顣顬顳颭颮颱颶颸颺颻颽颾颿飀飂飈飌飜飡飣飤飥飩飫飮飱飶餀餂餄餎餇餈餑餔餕餖餗餚餛餜餟餠餤餧餩餪餫餬餮餱餲餳餺餻餼餽餿饁饅饇饉饊饍饎饐饘饟饢馘馥馝馡馣騮騾馵馹駃駄駅駆駉駋駑駓駔駗駘駙駜駡駢駪駬駰駴駸駹駽駾騂騄騅騆騉騋騍騏驎騑騒験騕騖騠騢騣騤騧驤騵騶騸騺驀驂驃驄驆驈驊驌驍驎驏驒驔驖驙驦驩驫骺鯁骫骭骯骱骴骶骷髏骾髁髂髄髆髈髐髑髕髖髙髝髞髟髡髣髧髪髫髭髯髲髳髹髺髽髾鬁鬃鬅鬈鬋鬎鬏鬐鬑鬒鬖鬗鬘鬙鬠鬣鬪鬫鬬鬮鬯鬰鬲鬵鬷魆魈魊魋魍魎魑魖鰾魛魟魣魦魨魬魴魵魸鮀鮁鮆鮌鮎鮑鮒鮓鮚鮞鮟鱇鮠鮦鮨鮪鮭鮶鮸鮿鯀鯄鯆鯇鯈鯔鯕鯖鯗鯙鯠鯤鯥鯫鯰鯷鯸鯿鰂鰆鶼鰉鰋鰐鰒鰕鰛鰜鰣鰤鰥鰦鰨鰩鰮鰳鰶鰷鱺鰼鰽鱀鱄鱅鱆鱈鱎鱐鱓鱔鱖鱘鱟鱠鱣鱨鱭鱮鱲鱵鱻鲅鳦鳧鳯鳲鳷鳻鴂鴃鴄鴆鴈鴎鴒鴔鴗鴛鴦鴝鵒鴟鴠鴢鴣鴥鴯鶓鴳鴴鴷鴽鵀鵁鵂鵓鵖鵙鵜鶘鵞鵟鵩鵪鵫鵵鵷鵻鵾鶂鶊鶏鶒鶖鶗鶡鶤鶦鶬鶱鶲鶵鶸鶹鶺鶿鷀鷁鷃鷄鷇鷈鷉鷊鷏鷓鷕鷖鷙鷞鷟鷥鷦鷯鷩鷫鷭鷳鷴鷽鷾鷿鸂鸇鸊鸏鸑鸒鸓鸕鸛鸜鸝鹸鹹鹺麀麂麃麄麇麋麌麐麑麒麚麛麝麤麩麪麫麮麯麰麺麾黁黈黌黢黒黓黕黙黝黟黥黦黧黮黰黱黲黶黹黻黼黽黿鼂鼃鼅鼈鼉鼏鼐鼒鼕鼖鼙鼚鼛鼡鼩鼱鼪鼫鼯鼷鼽齁齆齇齈齉齌齎齏齔齕齗齙齚齜齞齟齬齠齢齣齧齩齮齯齰齱齵齾龎龑龒龔龖龘龝龡龢龤'
20
+
21
+ assert len(simplified_charcters) == len(simplified_charcters)
22
+
23
+ s2t_dict = {}
24
+ t2s_dict = {}
25
+ for i, item in enumerate(simplified_charcters):
26
+ s2t_dict[item] = traditional_characters[i]
27
+ t2s_dict[traditional_characters[i]] = item
28
+
29
+
30
+ def tranditional_to_simplified(text: str) -> str:
31
+ return "".join(
32
+ [t2s_dict[item] if item in t2s_dict else item for item in text])
33
+
34
+
35
+ def simplified_to_traditional(text: str) -> str:
36
+ return "".join(
37
+ [s2t_dict[item] if item in s2t_dict else item for item in text])
38
+
39
+
40
+ if __name__ == "__main__":
41
+ text = "一般是指存取一個應用程式啟動時始終顯示在網站或網頁瀏覽器中的一個或多個初始網頁等畫面存在的站點"
42
+ print(text)
43
+ text_simple = tranditional_to_simplified(text)
44
+ print(text_simple)
45
+ text_traditional = simplified_to_traditional(text_simple)
46
+ print(text_traditional)
SongBloom/g2p/cn_zh_g2p/zh_normalization/chronology.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import re
15
+
16
+ from .num import DIGITS
17
+ from .num import num2str
18
+ from .num import verbalize_cardinal
19
+ from .num import verbalize_digit
20
+
21
+
22
+ def _time_num2str(num_string: str) -> str:
23
+ """A special case for verbalizing number in time."""
24
+ result = num2str(num_string.lstrip('0'))
25
+ if num_string.startswith('0'):
26
+ result = DIGITS['0'] + result
27
+ return result
28
+
29
+
30
+ # 时刻表达式
31
+ RE_TIME = re.compile(r'([0-1]?[0-9]|2[0-3])'
32
+ r':([0-5][0-9])'
33
+ r'(:([0-5][0-9]))?')
34
+
35
+ # 时间范围,如8:30-12:30
36
+ RE_TIME_RANGE = re.compile(r'([0-1]?[0-9]|2[0-3])'
37
+ r':([0-5][0-9])'
38
+ r'(:([0-5][0-9]))?'
39
+ r'(~|-)'
40
+ r'([0-1]?[0-9]|2[0-3])'
41
+ r':([0-5][0-9])'
42
+ r'(:([0-5][0-9]))?')
43
+
44
+
45
+ def replace_time(match) -> str:
46
+ """
47
+ Args:
48
+ match (re.Match)
49
+ Returns:
50
+ str
51
+ """
52
+
53
+ is_range = len(match.groups()) > 5
54
+
55
+ hour = match.group(1)
56
+ minute = match.group(2)
57
+ second = match.group(4)
58
+
59
+ if is_range:
60
+ hour_2 = match.group(6)
61
+ minute_2 = match.group(7)
62
+ second_2 = match.group(9)
63
+
64
+ result = f"{num2str(hour)}点"
65
+ if minute.lstrip('0'):
66
+ if int(minute) == 30:
67
+ result += "半"
68
+ else:
69
+ result += f"{_time_num2str(minute)}分"
70
+ if second and second.lstrip('0'):
71
+ result += f"{_time_num2str(second)}秒"
72
+
73
+ if is_range:
74
+ result += "至"
75
+ result += f"{num2str(hour_2)}点"
76
+ if minute_2.lstrip('0'):
77
+ if int(minute) == 30:
78
+ result += "半"
79
+ else:
80
+ result += f"{_time_num2str(minute_2)}分"
81
+ if second_2 and second_2.lstrip('0'):
82
+ result += f"{_time_num2str(second_2)}秒"
83
+
84
+ return result
85
+
86
+
87
+ RE_DATE = re.compile(r'(\d{4}|\d{2})年'
88
+ r'((0?[1-9]|1[0-2])月)?'
89
+ r'(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?')
90
+
91
+
92
+ def replace_date(match) -> str:
93
+ """
94
+ Args:
95
+ match (re.Match)
96
+ Returns:
97
+ str
98
+ """
99
+ year = match.group(1)
100
+ month = match.group(3)
101
+ day = match.group(5)
102
+ result = ""
103
+ if year:
104
+ result += f"{verbalize_digit(year)}年"
105
+ if month:
106
+ result += f"{verbalize_cardinal(month)}月"
107
+ if day:
108
+ result += f"{verbalize_cardinal(day)}{match.group(9)}"
109
+ return result
110
+
111
+
112
+ # 用 / 或者 - 分隔的 YY/MM/DD 或者 YY-MM-DD 日期
113
+ RE_DATE2 = re.compile(
114
+ r'(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])')
115
+
116
+
117
+ def replace_date2(match) -> str:
118
+ """
119
+ Args:
120
+ match (re.Match)
121
+ Returns:
122
+ str
123
+ """
124
+ year = match.group(1)
125
+ month = match.group(3)
126
+ day = match.group(4)
127
+ result = ""
128
+ if year:
129
+ result += f"{verbalize_digit(year)}年"
130
+ if month:
131
+ result += f"{verbalize_cardinal(month)}月"
132
+ if day:
133
+ result += f"{verbalize_cardinal(day)}日"
134
+ return result
SongBloom/g2p/cn_zh_g2p/zh_normalization/constants.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import re
15
+ import string
16
+
17
+ from pypinyin.constants import SUPPORT_UCS4
18
+
19
+ # 全角半角转换
20
+ # 英文字符全角 -> 半角映射表 (num: 52)
21
+ F2H_ASCII_LETTERS = {
22
+ ord(char) + 65248: ord(char)
23
+ for char in string.ascii_letters
24
+ }
25
+
26
+ # 英文字符半角 -> 全角映射表
27
+ H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()}
28
+
29
+ # 数字字符全角 -> 半角映射表 (num: 10)
30
+ F2H_DIGITS = {ord(char) + 65248: ord(char) for char in string.digits}
31
+ # 数字字符半角 -> 全角映射表
32
+ H2F_DIGITS = {value: key for key, value in F2H_DIGITS.items()}
33
+
34
+ # 标点符号全角 -> 半角映射表 (num: 32)
35
+ F2H_PUNCTUATIONS = {ord(char) + 65248: ord(char) for char in string.punctuation}
36
+ # 标点符号半角 -> 全角映射表
37
+ H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()}
38
+
39
+ # 空格 (num: 1)
40
+ F2H_SPACE = {'\u3000': ' '}
41
+ H2F_SPACE = {' ': '\u3000'}
42
+
43
+ # 非"有拼音的汉字"的字符串,可用于NSW提取
44
+ if SUPPORT_UCS4:
45
+ RE_NSW = re.compile(r'(?:[^'
46
+ r'\u3007' # 〇
47
+ r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF]
48
+ r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF]
49
+ r'\uf900-\ufaff' # CJK兼容:[F900-FAFF]
50
+ r'\U00020000-\U0002A6DF' # CJK扩展B:[20000-2A6DF]
51
+ r'\U0002A703-\U0002B73F' # CJK扩展C:[2A700-2B73F]
52
+ r'\U0002B740-\U0002B81D' # CJK扩展D:[2B740-2B81D]
53
+ r'\U0002F80A-\U0002FA1F' # CJK兼容扩展:[2F800-2FA1F]
54
+ r'])+')
55
+ else:
56
+ RE_NSW = re.compile( # pragma: no cover
57
+ r'(?:[^'
58
+ r'\u3007' # 〇
59
+ r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF]
60
+ r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF]
61
+ r'\uf900-\ufaff' # CJK兼容:[F900-FAFF]
62
+ r'])+')
SongBloom/g2p/cn_zh_g2p/zh_normalization/num.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Rules to verbalize numbers into Chinese characters.
16
+ https://zh.wikipedia.org/wiki/中文数字#現代中文
17
+ """
18
+ import re
19
+ from collections import OrderedDict
20
+ from typing import List
21
+
22
+ DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')}
23
+ UNITS = OrderedDict({
24
+ 1: '十',
25
+ 2: '百',
26
+ 3: '千',
27
+ 4: '万',
28
+ 8: '亿',
29
+ })
30
+
31
+ COM_QUANTIFIERS = '(封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)'
32
+
33
+ # 分数表达式
34
+ RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')
35
+
36
+
37
+ def replace_frac(match) -> str:
38
+ """
39
+ Args:
40
+ match (re.Match)
41
+ Returns:
42
+ str
43
+ """
44
+ sign = match.group(1)
45
+ nominator = match.group(2)
46
+ denominator = match.group(3)
47
+ sign: str = "负" if sign else ""
48
+ nominator: str = num2str(nominator)
49
+ denominator: str = num2str(denominator)
50
+ result = f"{sign}{denominator}分之{nominator}"
51
+ return result
52
+
53
+
54
+ # 百分数表达式
55
+ RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%')
56
+
57
+
58
+ def replace_percentage(match) -> str:
59
+ """
60
+ Args:
61
+ match (re.Match)
62
+ Returns:
63
+ str
64
+ """
65
+ sign = match.group(1)
66
+ percent = match.group(2)
67
+ sign: str = "负" if sign else ""
68
+ percent: str = num2str(percent)
69
+ result = f"{sign}百分之{percent}"
70
+ return result
71
+
72
+
73
+ # 整数表达式
74
+ # 带负号的整数 -10
75
+ RE_INTEGER = re.compile(r'(-)' r'(\d+)')
76
+
77
+
78
+ def replace_negative_num(match) -> str:
79
+ """
80
+ Args:
81
+ match (re.Match)
82
+ Returns:
83
+ str
84
+ """
85
+ sign = match.group(1)
86
+ number = match.group(2)
87
+ sign: str = "负" if sign else ""
88
+ number: str = num2str(number)
89
+ result = f"{sign}{number}"
90
+ return result
91
+
92
+
93
+ # 编号-无符号整形
94
+ # 00078
95
+ RE_DEFAULT_NUM = re.compile(r'\d{3}\d*')
96
+
97
+
98
+ def replace_default_num(match):
99
+ """
100
+ Args:
101
+ match (re.Match)
102
+ Returns:
103
+ str
104
+ """
105
+ number = match.group(0)
106
+ return verbalize_digit(number, alt_one=True)
107
+
108
+
109
+ # 加减乘除
110
+ RE_ASMD = re.compile(
111
+ r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
112
+ asmd_map = {
113
+ '+': '加',
114
+ '-': '减',
115
+ '×': '乘',
116
+ '÷': '除',
117
+ '=': '等于'
118
+ }
119
+
120
+
121
+ def replace_asmd(match) -> str:
122
+ """
123
+ Args:
124
+ match (re.Match)
125
+ Returns:
126
+ str
127
+ """
128
+ result = match.group(1) + asmd_map[match.group(8)] + match.group(9)
129
+ return result
130
+
131
+
132
+ # 数字表达式
133
+ # 纯小数
134
+ RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')
135
+ # 正整数 + 量词
136
+ RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" + COM_QUANTIFIERS)
137
+ RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))')
138
+
139
+
140
+ def replace_positive_quantifier(match) -> str:
141
+ """
142
+ Args:
143
+ match (re.Match)
144
+ Returns:
145
+ str
146
+ """
147
+ number = match.group(1)
148
+ match_2 = match.group(2)
149
+ if match_2 == "+":
150
+ match_2 = "多"
151
+ match_2: str = match_2 if match_2 else ""
152
+ quantifiers: str = match.group(3)
153
+ number: str = num2str(number)
154
+ result = f"{number}{match_2}{quantifiers}"
155
+ return result
156
+
157
+
158
+ def replace_number(match) -> str:
159
+ """
160
+ Args:
161
+ match (re.Match)
162
+ Returns:
163
+ str
164
+ """
165
+ sign = match.group(1)
166
+ number = match.group(2)
167
+ pure_decimal = match.group(5)
168
+ if pure_decimal:
169
+ result = num2str(pure_decimal)
170
+ else:
171
+ sign: str = "负" if sign else ""
172
+ number: str = num2str(number)
173
+ result = f"{sign}{number}"
174
+ return result
175
+
176
+
177
+ # 范围表达式
178
+ # match.group(1) and match.group(8) are copy from RE_NUMBER
179
+
180
+ RE_RANGE = re.compile(
181
+ r"""
182
+ (?<![\d\+\-\×÷=]) # 使用反向前瞻以确保数字范围之前没有其他数字和操作符
183
+ ((-?)((\d+)(\.\d+)?)) # 匹配范围起始的负数或正数(整数或小数)
184
+ [-~] # 匹配范围分隔符
185
+ ((-?)((\d+)(\.\d+)?)) # 匹配范围结束的负数或正数(整数或小数)
186
+ (?![\d\+\-\×÷=]) # 使用正向前瞻以确保数字范围之后没有其他数字和操作符
187
+ """, re.VERBOSE)
188
+
189
+
190
+ def replace_range(match) -> str:
191
+ """
192
+ Args:
193
+ match (re.Match)
194
+ Returns:
195
+ str
196
+ """
197
+ first, second = match.group(1), match.group(6)
198
+ first = RE_NUMBER.sub(replace_number, first)
199
+ second = RE_NUMBER.sub(replace_number, second)
200
+ result = f"{first}到{second}"
201
+ return result
202
+
203
+
204
+ # ~至表达式
205
+ RE_TO_RANGE = re.compile(
206
+ r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)')
207
+
208
+ def replace_to_range(match) -> str:
209
+ """
210
+ Args:
211
+ match (re.Match)
212
+ Returns:
213
+ str
214
+ """
215
+ result = match.group(0).replace('~', '至')
216
+ return result
217
+
218
+
219
+ def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
220
+ stripped = value_string.lstrip('0')
221
+ if len(stripped) == 0:
222
+ return []
223
+ elif len(stripped) == 1:
224
+ if use_zero and len(stripped) < len(value_string):
225
+ return [DIGITS['0'], DIGITS[stripped]]
226
+ else:
227
+ return [DIGITS[stripped]]
228
+ else:
229
+ largest_unit = next(
230
+ power for power in reversed(UNITS.keys()) if power < len(stripped))
231
+ first_part = value_string[:-largest_unit]
232
+ second_part = value_string[-largest_unit:]
233
+ return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(
234
+ second_part)
235
+
236
+
237
+ def verbalize_cardinal(value_string: str) -> str:
238
+ if not value_string:
239
+ return ''
240
+
241
+ # 000 -> '零' , 0 -> '零'
242
+ value_string = value_string.lstrip('0')
243
+ if len(value_string) == 0:
244
+ return DIGITS['0']
245
+
246
+ result_symbols = _get_value(value_string)
247
+ # verbalized number starting with '一十*' is abbreviated as `十*`
248
+ if len(result_symbols) >= 2 and result_symbols[0] == DIGITS[
249
+ '1'] and result_symbols[1] == UNITS[1]:
250
+ result_symbols = result_symbols[1:]
251
+ return ''.join(result_symbols)
252
+
253
+
254
+ def verbalize_digit(value_string: str, alt_one=False) -> str:
255
+ result_symbols = [DIGITS[digit] for digit in value_string]
256
+ result = ''.join(result_symbols)
257
+ if alt_one:
258
+ result = result.replace("一", "幺")
259
+ return result
260
+
261
+
262
+ def num2str(value_string: str) -> str:
263
+ integer_decimal = value_string.split('.')
264
+ if len(integer_decimal) == 1:
265
+ integer = integer_decimal[0]
266
+ decimal = ''
267
+ elif len(integer_decimal) == 2:
268
+ integer, decimal = integer_decimal
269
+ else:
270
+ raise ValueError(
271
+ f"The value string: '${value_string}' has more than one point in it."
272
+ )
273
+
274
+ result = verbalize_cardinal(integer)
275
+
276
+ decimal = decimal.rstrip('0')
277
+ if decimal:
278
+ # '.22' is verbalized as '零点二二'
279
+ # '3.20' is verbalized as '三点二
280
+ result = result if result else "零"
281
+ result += '点' + verbalize_digit(decimal)
282
+ return result
SongBloom/g2p/cn_zh_g2p/zh_normalization/phonecode.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import re
15
+
16
+ from .num import verbalize_digit
17
+
18
+ # 规范化固话/手机号码
19
+ # 手机
20
+ # http://www.jihaoba.com/news/show/13680
21
+ # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
22
+ # 联通:130、131、132、156、155、186、185、176
23
+ # 电信:133、153、189、180、181、177
24
+ RE_MOBILE_PHONE = re.compile(
25
+ r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
26
+ RE_TELEPHONE = re.compile(
27
+ r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})(?!\d)")
28
+
29
+ # 全国统一的号码400开头
30
+ RE_NATIONAL_UNIFORM_NUMBER = re.compile(r"(400)(-)?\d{3}(-)?\d{4}")
31
+
32
+
33
+ def phone2str(phone_string: str, mobile=True) -> str:
34
+ if mobile:
35
+ sp_parts = phone_string.strip('+').split()
36
+ result = ','.join(
37
+ [verbalize_digit(part, alt_one=True) for part in sp_parts])
38
+ return result
39
+ else:
40
+ sil_parts = phone_string.split('-')
41
+ result = ','.join(
42
+ [verbalize_digit(part, alt_one=True) for part in sil_parts])
43
+ return result
44
+
45
+
46
+ def replace_phone(match) -> str:
47
+ """
48
+ Args:
49
+ match (re.Match)
50
+ Returns:
51
+ str
52
+ """
53
+ return phone2str(match.group(0), mobile=False)
54
+
55
+
56
+ def replace_mobile(match) -> str:
57
+ """
58
+ Args:
59
+ match (re.Match)
60
+ Returns:
61
+ str
62
+ """
63
+ return phone2str(match.group(0))
SongBloom/g2p/cn_zh_g2p/zh_normalization/quantifier.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import re
15
+
16
+ from .num import num2str
17
+
18
+ # 温度表达式,温度会影响负号的读法
19
+ # -3°C 零下三度
20
+ RE_TEMPERATURE = re.compile(r'(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)')
21
+ measure_dict = {
22
+ "cm2": "平方厘米",
23
+ "cm²": "平方厘米",
24
+ "cm3": "立方厘米",
25
+ "cm³": "立方厘米",
26
+ "cm": "厘米",
27
+ "db": "分贝",
28
+ "ds": "毫秒",
29
+ "kg": "千克",
30
+ "km": "千米",
31
+ "m2": "平方米",
32
+ "m²": "平方米",
33
+ "m³": "立方米",
34
+ "m3": "立方米",
35
+ "ml": "毫升",
36
+ "m": "米",
37
+ "mm": "毫米",
38
+ "s": "秒"
39
+ }
40
+
41
+
42
+ def replace_temperature(match) -> str:
43
+ """
44
+ Args:
45
+ match (re.Match)
46
+ Returns:
47
+ str
48
+ """
49
+ sign = match.group(1)
50
+ temperature = match.group(2)
51
+ unit = match.group(3)
52
+ sign: str = "零下" if sign else ""
53
+ temperature: str = num2str(temperature)
54
+ unit: str = "摄氏度" if unit == "摄氏度" else "度"
55
+ result = f"{sign}{temperature}{unit}"
56
+ return result
57
+
58
+
59
+ def replace_measure(sentence) -> str:
60
+ for q_notation in measure_dict:
61
+ if q_notation in sentence:
62
+ sentence = sentence.replace(q_notation, measure_dict[q_notation])
63
+ return sentence
SongBloom/g2p/cn_zh_g2p/zh_normalization/text_normlization.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import re
15
+ from typing import List
16
+
17
+ from .char_convert import tranditional_to_simplified
18
+ from .chronology import RE_DATE
19
+ from .chronology import RE_DATE2
20
+ from .chronology import RE_TIME
21
+ from .chronology import RE_TIME_RANGE
22
+ from .chronology import replace_date
23
+ from .chronology import replace_date2
24
+ from .chronology import replace_time
25
+ from .constants import F2H_ASCII_LETTERS
26
+ from .constants import F2H_DIGITS
27
+ from .constants import F2H_SPACE
28
+ from .num import RE_DECIMAL_NUM
29
+ from .num import RE_DEFAULT_NUM
30
+ from .num import RE_FRAC
31
+ from .num import RE_INTEGER
32
+ from .num import RE_NUMBER
33
+ from .num import RE_PERCENTAGE
34
+ from .num import RE_POSITIVE_QUANTIFIERS
35
+ from .num import RE_RANGE
36
+ from .num import RE_TO_RANGE
37
+ from .num import RE_ASMD
38
+ from .num import replace_default_num
39
+ from .num import replace_frac
40
+ from .num import replace_negative_num
41
+ from .num import replace_number
42
+ from .num import replace_percentage
43
+ from .num import replace_positive_quantifier
44
+ from .num import replace_range
45
+ from .num import replace_to_range
46
+ from .num import replace_asmd
47
+ from .phonecode import RE_MOBILE_PHONE
48
+ from .phonecode import RE_NATIONAL_UNIFORM_NUMBER
49
+ from .phonecode import RE_TELEPHONE
50
+ from .phonecode import replace_mobile
51
+ from .phonecode import replace_phone
52
+ from .quantifier import RE_TEMPERATURE
53
+ from .quantifier import replace_measure
54
+ from .quantifier import replace_temperature
55
+
56
+
57
+ class TextNormalizer():
58
+ def __init__(self):
59
+ self.SENTENCE_SPLITOR = re.compile(r'([:、,;。?!,;?!][”’]?)')
60
+
61
+ def _split(self, text: str, lang="zh") -> List[str]:
62
+ """Split long text into sentences with sentence-splitting punctuations.
63
+ Args:
64
+ text (str): The input text.
65
+ Returns:
66
+ List[str]: Sentences.
67
+ """
68
+ # Only for pure Chinese here
69
+ if lang == "zh":
70
+ text = text.replace(" ", "")
71
+ # 过滤掉特殊字符
72
+ text = re.sub(r'[——《》【】<>{}()()#&@“”^_|\\]', '', text)
73
+ text = self.SENTENCE_SPLITOR.sub(r'\1\n', text)
74
+ text = text.strip()
75
+ sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]
76
+ return sentences
77
+
78
+ def _post_replace(self, sentence: str) -> str:
79
+ sentence = sentence.replace('/', '每')
80
+ # sentence = sentence.replace('~', '至')
81
+ # sentence = sentence.replace('~', '至')
82
+ sentence = sentence.replace('①', '一')
83
+ sentence = sentence.replace('②', '二')
84
+ sentence = sentence.replace('③', '三')
85
+ sentence = sentence.replace('④', '四')
86
+ sentence = sentence.replace('⑤', '五')
87
+ sentence = sentence.replace('⑥', '六')
88
+ sentence = sentence.replace('⑦', '七')
89
+ sentence = sentence.replace('⑧', '八')
90
+ sentence = sentence.replace('⑨', '九')
91
+ sentence = sentence.replace('⑩', '十')
92
+ sentence = sentence.replace('α', '阿尔法')
93
+ sentence = sentence.replace('β', '贝塔')
94
+ sentence = sentence.replace('γ', '伽玛').replace('Γ', '伽玛')
95
+ sentence = sentence.replace('δ', '德尔塔').replace('Δ', '德尔塔')
96
+ sentence = sentence.replace('ε', '艾普西龙')
97
+ sentence = sentence.replace('ζ', '捷塔')
98
+ sentence = sentence.replace('η', '依塔')
99
+ sentence = sentence.replace('θ', '西塔').replace('Θ', '西塔')
100
+ sentence = sentence.replace('ι', '艾欧塔')
101
+ sentence = sentence.replace('κ', '喀帕')
102
+ sentence = sentence.replace('λ', '拉姆达').replace('Λ', '拉姆达')
103
+ sentence = sentence.replace('μ', '缪')
104
+ sentence = sentence.replace('ν', '拗')
105
+ sentence = sentence.replace('ξ', '克西').replace('Ξ', '克西')
106
+ sentence = sentence.replace('ο', '欧米克伦')
107
+ sentence = sentence.replace('π', '派').replace('Π', '派')
108
+ sentence = sentence.replace('ρ', '肉')
109
+ sentence = sentence.replace('ς', '西格玛').replace('Σ', '西格玛').replace(
110
+ 'σ', '西格玛')
111
+ sentence = sentence.replace('τ', '套')
112
+ sentence = sentence.replace('υ', '宇普西龙')
113
+ sentence = sentence.replace('φ', '服艾').replace('Φ', '服艾')
114
+ sentence = sentence.replace('χ', '器')
115
+ sentence = sentence.replace('ψ', '普赛').replace('Ψ', '普赛')
116
+ sentence = sentence.replace('ω', '欧米伽').replace('Ω', '欧米伽')
117
+ # re filter special characters, have one more character "-" than line 68
118
+ sentence = re.sub(r'[-——《》【】<=>{}()()#&@“”^_|\\]', '', sentence)
119
+ return sentence
120
+
121
+ def normalize_sentence(self, sentence: str) -> str:
122
+ # basic character conversions
123
+ sentence = tranditional_to_simplified(sentence)
124
+ sentence = sentence.translate(F2H_ASCII_LETTERS).translate(
125
+ F2H_DIGITS).translate(F2H_SPACE)
126
+
127
+ # number related NSW verbalization
128
+ sentence = RE_DATE.sub(replace_date, sentence)
129
+ sentence = RE_DATE2.sub(replace_date2, sentence)
130
+
131
+ # range first
132
+ sentence = RE_TIME_RANGE.sub(replace_time, sentence)
133
+ sentence = RE_TIME.sub(replace_time, sentence)
134
+
135
+ # 处理~波浪号作为至的替换
136
+ sentence = RE_TO_RANGE.sub(replace_to_range, sentence)
137
+ sentence = RE_TEMPERATURE.sub(replace_temperature, sentence)
138
+ sentence = replace_measure(sentence)
139
+ sentence = RE_FRAC.sub(replace_frac, sentence)
140
+ sentence = RE_PERCENTAGE.sub(replace_percentage, sentence)
141
+ sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence)
142
+
143
+ sentence = RE_TELEPHONE.sub(replace_phone, sentence)
144
+ sentence = RE_NATIONAL_UNIFORM_NUMBER.sub(replace_phone, sentence)
145
+
146
+ sentence = RE_RANGE.sub(replace_range, sentence)
147
+
148
+ # 处理加减乘除
149
+ while RE_ASMD.search(sentence):
150
+ sentence = RE_ASMD.sub(replace_asmd, sentence)
151
+
152
+ sentence = RE_INTEGER.sub(replace_negative_num, sentence)
153
+ sentence = RE_DECIMAL_NUM.sub(replace_number, sentence)
154
+ sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier,
155
+ sentence)
156
+ sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence)
157
+ sentence = RE_NUMBER.sub(replace_number, sentence)
158
+ sentence = self._post_replace(sentence)
159
+
160
+ return sentence
161
+
162
+ def normalize(self, text: str) -> List[str]:
163
+ sentences = self._split(text)
164
+ sentences = [self.normalize_sentence(sent) for sent in sentences]
165
+ return sentences
SongBloom/g2p/lyric_common.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
+ sys.path.insert(0, os.path.dirname(__file__))
4
+ from pinyin.pinyin import G2P_PinYin
5
+ from cn_zh_g2p import G2P_Mix, symbols
6
+
7
+ key2processor = {
8
+ 'pinyin': G2P_PinYin(),
9
+ 'phoneme': G2P_Mix(),
10
+ }
11
+
12
+ valid_struct_type = ['[chorus]', '[verse]', '[bridge]']
13
+ start_struct_type = ['[intro]', '[start]']
14
+ end_struct_type = ['[outro]', '[end]']
15
+ conn_struct_type = ['[inst]', '[solo]', '[break]']
16
+
17
+ LABELS = {
18
+ '[intro]': 0,
19
+ '[outro]': 1,
20
+ '[bridge]': 2,
21
+ '[inst]': 3,
22
+ '[verse]': 4,
23
+ '[chorus]': 5,
24
+ '[silence]': 6,
25
+ }
26
+
27
+ NUMBERS = {
28
+ '0': ['零', 'zero'],
29
+ '1': ['一', 'one'],
30
+ '2': ['二', 'two'],
31
+ '3': ['三', 'three'],
32
+ '4': ['四', 'four'],
33
+ '5': ['五', 'five'],
34
+ '6': ['六', 'six'],
35
+ '7': ['七', 'seven'],
36
+ '8': ['八', 'eight'],
37
+ '9': ['九', 'nine']
38
+ }
39
+
40
+ def detect_structure(structure):
41
+ valid_start = ['start', 'intro']
42
+ valid_end = ['outro', 'end']
43
+ valid_instru = ['solo', 'inst', 'break']
44
+ valid_bridge = ['bridge']
45
+
46
+ if structure in ['verse', 'chorus', 'silence']:
47
+ return structure
48
+
49
+ if structure in valid_start:
50
+ return 'intro'
51
+ if structure in valid_end:
52
+ return 'outro'
53
+ if structure in valid_instru:
54
+ return 'inst'
55
+ if structure in valid_bridge:
56
+ return 'bridge'
57
+
58
+ def merge_structure(start_time, end_time, structure, lyric):
59
+ cnt = 1
60
+ while cnt < len(start_time):
61
+ if structure[cnt] == structure[cnt-1]:
62
+ end_time[cnt-1] = end_time[cnt]
63
+ if structure[cnt] not in ["verse", "chorus", "bridge"]:
64
+ del start_time[cnt]
65
+ del end_time[cnt]
66
+ del structure[cnt]
67
+ del lyric[cnt]
68
+ else:
69
+ cnt += 1
70
+ else:
71
+ cnt += 1
72
+
73
+ return start_time, end_time, structure, lyric
74
+
75
+
76
+ def is_struct_legal(struct, text):
77
+ if struct in valid_struct_type and text != "":
78
+ return True
79
+ elif struct not in valid_struct_type and text == "":
80
+ return True
81
+ return False
SongBloom/g2p/pinyin/__init__.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .symbols import symbols
2
+
3
+
4
+
5
+
6
+ pinyin_dict = {
7
+ "a": ("^", "a"),
8
+ "ai": ("^", "ai"),
9
+ "an": ("^", "an"),
10
+ "ang": ("^", "ang"),
11
+ "ao": ("^", "ao"),
12
+ "ba": ("b", "a"),
13
+ "bai": ("b", "ai"),
14
+ "ban": ("b", "an"),
15
+ "bang": ("b", "ang"),
16
+ "bao": ("b", "ao"),
17
+ "be": ("b", "e"),
18
+ "bei": ("b", "ei"),
19
+ "ben": ("b", "en"),
20
+ "beng": ("b", "eng"),
21
+ "bi": ("b", "i"),
22
+ "bian": ("b", "ian"),
23
+ "biao": ("b", "iao"),
24
+ "bie": ("b", "ie"),
25
+ "bin": ("b", "in"),
26
+ "bing": ("b", "ing"),
27
+ "bo": ("b", "o"),
28
+ "bu": ("b", "u"),
29
+ "ca": ("c", "a"),
30
+ "cai": ("c", "ai"),
31
+ "can": ("c", "an"),
32
+ "cang": ("c", "ang"),
33
+ "cao": ("c", "ao"),
34
+ "ce": ("c", "e"),
35
+ "cen": ("c", "en"),
36
+ "ceng": ("c", "eng"),
37
+ "cha": ("ch", "a"),
38
+ "chai": ("ch", "ai"),
39
+ "chan": ("ch", "an"),
40
+ "chang": ("ch", "ang"),
41
+ "chao": ("ch", "ao"),
42
+ "che": ("ch", "e"),
43
+ "chen": ("ch", "en"),
44
+ "cheng": ("ch", "eng"),
45
+ "chi": ("ch", "iii"),
46
+ "chong": ("ch", "ong"),
47
+ "chou": ("ch", "ou"),
48
+ "chu": ("ch", "u"),
49
+ "chua": ("ch", "ua"),
50
+ "chuai": ("ch", "uai"),
51
+ "chuan": ("ch", "uan"),
52
+ "chuang": ("ch", "uang"),
53
+ "chui": ("ch", "uei"),
54
+ "chun": ("ch", "uen"),
55
+ "chuo": ("ch", "uo"),
56
+ "ci": ("c", "ii"),
57
+ "cong": ("c", "ong"),
58
+ "cou": ("c", "ou"),
59
+ "cu": ("c", "u"),
60
+ "cuan": ("c", "uan"),
61
+ "cui": ("c", "uei"),
62
+ "cun": ("c", "uen"),
63
+ "cuo": ("c", "uo"),
64
+ "da": ("d", "a"),
65
+ "dai": ("d", "ai"),
66
+ "dan": ("d", "an"),
67
+ "dang": ("d", "ang"),
68
+ "dao": ("d", "ao"),
69
+ "de": ("d", "e"),
70
+ "dei": ("d", "ei"),
71
+ "den": ("d", "en"),
72
+ "deng": ("d", "eng"),
73
+ "di": ("d", "i"),
74
+ "dia": ("d", "ia"),
75
+ "dian": ("d", "ian"),
76
+ "diao": ("d", "iao"),
77
+ "die": ("d", "ie"),
78
+ "ding": ("d", "ing"),
79
+ "diu": ("d", "iou"),
80
+ "dong": ("d", "ong"),
81
+ "dou": ("d", "ou"),
82
+ "du": ("d", "u"),
83
+ "duan": ("d", "uan"),
84
+ "dui": ("d", "uei"),
85
+ "dun": ("d", "uen"),
86
+ "duo": ("d", "uo"),
87
+ "e": ("^", "e"),
88
+ "ei": ("^", "ei"),
89
+ "en": ("^", "en"),
90
+ "ng": ("^", "en"),
91
+ "eng": ("^", "eng"),
92
+ "er": ("^", "er"),
93
+ "fa": ("f", "a"),
94
+ "fan": ("f", "an"),
95
+ "fang": ("f", "ang"),
96
+ "fei": ("f", "ei"),
97
+ "fen": ("f", "en"),
98
+ "feng": ("f", "eng"),
99
+ "fo": ("f", "o"),
100
+ "fou": ("f", "ou"),
101
+ "fu": ("f", "u"),
102
+ "ga": ("g", "a"),
103
+ "gai": ("g", "ai"),
104
+ "gan": ("g", "an"),
105
+ "gang": ("g", "ang"),
106
+ "gao": ("g", "ao"),
107
+ "ge": ("g", "e"),
108
+ "gei": ("g", "ei"),
109
+ "gen": ("g", "en"),
110
+ "geng": ("g", "eng"),
111
+ "gong": ("g", "ong"),
112
+ "gou": ("g", "ou"),
113
+ "gu": ("g", "u"),
114
+ "gua": ("g", "ua"),
115
+ "guai": ("g", "uai"),
116
+ "guan": ("g", "uan"),
117
+ "guang": ("g", "uang"),
118
+ "gui": ("g", "uei"),
119
+ "gun": ("g", "uen"),
120
+ "guo": ("g", "uo"),
121
+ "ha": ("h", "a"),
122
+ "hai": ("h", "ai"),
123
+ "han": ("h", "an"),
124
+ "hang": ("h", "ang"),
125
+ "hao": ("h", "ao"),
126
+ "he": ("h", "e"),
127
+ "hei": ("h", "ei"),
128
+ "hen": ("h", "en"),
129
+ "heng": ("h", "eng"),
130
+ "hong": ("h", "ong"),
131
+ "hou": ("h", "ou"),
132
+ "hu": ("h", "u"),
133
+ "hua": ("h", "ua"),
134
+ "huai": ("h", "uai"),
135
+ "huan": ("h", "uan"),
136
+ "huang": ("h", "uang"),
137
+ "hui": ("h", "uei"),
138
+ "hun": ("h", "uen"),
139
+ "huo": ("h", "uo"),
140
+ "ji": ("j", "i"),
141
+ "jia": ("j", "ia"),
142
+ "jian": ("j", "ian"),
143
+ "jiang": ("j", "iang"),
144
+ "jiao": ("j", "iao"),
145
+ "jie": ("j", "ie"),
146
+ "jin": ("j", "in"),
147
+ "jing": ("j", "ing"),
148
+ "jiong": ("j", "iong"),
149
+ "jiu": ("j", "iou"),
150
+ "ju": ("j", "v"),
151
+ "juan": ("j", "van"),
152
+ "jue": ("j", "ve"),
153
+ "jun": ("j", "vn"),
154
+ "ka": ("k", "a"),
155
+ "kai": ("k", "ai"),
156
+ "kan": ("k", "an"),
157
+ "kang": ("k", "ang"),
158
+ "kao": ("k", "ao"),
159
+ "ke": ("k", "e"),
160
+ "kei": ("k", "ei"),
161
+ "ken": ("k", "en"),
162
+ "keng": ("k", "eng"),
163
+ "kong": ("k", "ong"),
164
+ "kou": ("k", "ou"),
165
+ "ku": ("k", "u"),
166
+ "kua": ("k", "ua"),
167
+ "kuai": ("k", "uai"),
168
+ "kuan": ("k", "uan"),
169
+ "kuang": ("k", "uang"),
170
+ "kui": ("k", "uei"),
171
+ "kun": ("k", "uen"),
172
+ "kuo": ("k", "uo"),
173
+ "la": ("l", "a"),
174
+ "lai": ("l", "ai"),
175
+ "lan": ("l", "an"),
176
+ "lang": ("l", "ang"),
177
+ "lao": ("l", "ao"),
178
+ "le": ("l", "e"),
179
+ "lei": ("l", "ei"),
180
+ "leng": ("l", "eng"),
181
+ "li": ("l", "i"),
182
+ "lia": ("l", "ia"),
183
+ "lian": ("l", "ian"),
184
+ "liang": ("l", "iang"),
185
+ "liao": ("l", "iao"),
186
+ "lie": ("l", "ie"),
187
+ "lin": ("l", "in"),
188
+ "ling": ("l", "ing"),
189
+ "liu": ("l", "iou"),
190
+ "lo": ("l", "o"),
191
+ "long": ("l", "ong"),
192
+ "lou": ("l", "ou"),
193
+ "lu": ("l", "u"),
194
+ "lv": ("l", "v"),
195
+ "luan": ("l", "uan"),
196
+ "lve": ("l", "ve"),
197
+ "lue": ("l", "ve"),
198
+ "lun": ("l", "uen"),
199
+ "luo": ("l", "uo"),
200
+ "ma": ("m", "a"),
201
+ "mai": ("m", "ai"),
202
+ "man": ("m", "an"),
203
+ "mang": ("m", "ang"),
204
+ "mao": ("m", "ao"),
205
+ "me": ("m", "e"),
206
+ "mei": ("m", "ei"),
207
+ "men": ("m", "en"),
208
+ "meng": ("m", "eng"),
209
+ "mi": ("m", "i"),
210
+ "mian": ("m", "ian"),
211
+ "miao": ("m", "iao"),
212
+ "mie": ("m", "ie"),
213
+ "min": ("m", "in"),
214
+ "ming": ("m", "ing"),
215
+ "miu": ("m", "iou"),
216
+ "mo": ("m", "o"),
217
+ "mou": ("m", "ou"),
218
+ "mu": ("m", "u"),
219
+ "na": ("n", "a"),
220
+ "nai": ("n", "ai"),
221
+ "nan": ("n", "an"),
222
+ "nang": ("n", "ang"),
223
+ "nao": ("n", "ao"),
224
+ "ne": ("n", "e"),
225
+ "nei": ("n", "ei"),
226
+ "nen": ("n", "en"),
227
+ "neng": ("n", "eng"),
228
+ "ni": ("n", "i"),
229
+ "nia": ("n", "ia"),
230
+ "nian": ("n", "ian"),
231
+ "niang": ("n", "iang"),
232
+ "niao": ("n", "iao"),
233
+ "nie": ("n", "ie"),
234
+ "nin": ("n", "in"),
235
+ "ning": ("n", "ing"),
236
+ "niu": ("n", "iou"),
237
+ "nong": ("n", "ong"),
238
+ "nou": ("n", "ou"),
239
+ "nu": ("n", "u"),
240
+ "nv": ("n", "v"),
241
+ "nuan": ("n", "uan"),
242
+ "nve": ("n", "ve"),
243
+ "nue": ("n", "ve"),
244
+ "nuo": ("n", "uo"),
245
+ "o": ("^", "o"),
246
+ "ou": ("^", "ou"),
247
+ "pa": ("p", "a"),
248
+ "pai": ("p", "ai"),
249
+ "pan": ("p", "an"),
250
+ "pang": ("p", "ang"),
251
+ "pao": ("p", "ao"),
252
+ "pe": ("p", "e"),
253
+ "pei": ("p", "ei"),
254
+ "pen": ("p", "en"),
255
+ "peng": ("p", "eng"),
256
+ "pi": ("p", "i"),
257
+ "pian": ("p", "ian"),
258
+ "piao": ("p", "iao"),
259
+ "pie": ("p", "ie"),
260
+ "pin": ("p", "in"),
261
+ "ping": ("p", "ing"),
262
+ "po": ("p", "o"),
263
+ "pou": ("p", "ou"),
264
+ "pu": ("p", "u"),
265
+ "qi": ("q", "i"),
266
+ "qia": ("q", "ia"),
267
+ "qian": ("q", "ian"),
268
+ "qiang": ("q", "iang"),
269
+ "qiao": ("q", "iao"),
270
+ "qie": ("q", "ie"),
271
+ "qin": ("q", "in"),
272
+ "qing": ("q", "ing"),
273
+ "qiong": ("q", "iong"),
274
+ "qiu": ("q", "iou"),
275
+ "qu": ("q", "v"),
276
+ "quan": ("q", "van"),
277
+ "que": ("q", "ve"),
278
+ "qun": ("q", "vn"),
279
+ "ran": ("r", "an"),
280
+ "rang": ("r", "ang"),
281
+ "rao": ("r", "ao"),
282
+ "re": ("r", "e"),
283
+ "ren": ("r", "en"),
284
+ "reng": ("r", "eng"),
285
+ "ri": ("r", "iii"),
286
+ "rong": ("r", "ong"),
287
+ "rou": ("r", "ou"),
288
+ "ru": ("r", "u"),
289
+ "rua": ("r", "ua"),
290
+ "ruan": ("r", "uan"),
291
+ "rui": ("r", "uei"),
292
+ "run": ("r", "uen"),
293
+ "ruo": ("r", "uo"),
294
+ "sa": ("s", "a"),
295
+ "sai": ("s", "ai"),
296
+ "san": ("s", "an"),
297
+ "sang": ("s", "ang"),
298
+ "sao": ("s", "ao"),
299
+ "se": ("s", "e"),
300
+ "sen": ("s", "en"),
301
+ "seng": ("s", "eng"),
302
+ "sha": ("sh", "a"),
303
+ "shai": ("sh", "ai"),
304
+ "shan": ("sh", "an"),
305
+ "shang": ("sh", "ang"),
306
+ "shao": ("sh", "ao"),
307
+ "she": ("sh", "e"),
308
+ "shei": ("sh", "ei"),
309
+ "shen": ("sh", "en"),
310
+ "sheng": ("sh", "eng"),
311
+ "shi": ("sh", "iii"),
312
+ "shou": ("sh", "ou"),
313
+ "shu": ("sh", "u"),
314
+ "shua": ("sh", "ua"),
315
+ "shuai": ("sh", "uai"),
316
+ "shuan": ("sh", "uan"),
317
+ "shuang": ("sh", "uang"),
318
+ "shui": ("sh", "uei"),
319
+ "shun": ("sh", "uen"),
320
+ "shuo": ("sh", "uo"),
321
+ "si": ("s", "ii"),
322
+ "song": ("s", "ong"),
323
+ "sou": ("s", "ou"),
324
+ "su": ("s", "u"),
325
+ "suan": ("s", "uan"),
326
+ "sui": ("s", "uei"),
327
+ "sun": ("s", "uen"),
328
+ "suo": ("s", "uo"),
329
+ "ta": ("t", "a"),
330
+ "tai": ("t", "ai"),
331
+ "tan": ("t", "an"),
332
+ "tang": ("t", "ang"),
333
+ "tao": ("t", "ao"),
334
+ "te": ("t", "e"),
335
+ "tei": ("t", "ei"),
336
+ "teng": ("t", "eng"),
337
+ "ti": ("t", "i"),
338
+ "tian": ("t", "ian"),
339
+ "tiao": ("t", "iao"),
340
+ "tie": ("t", "ie"),
341
+ "ting": ("t", "ing"),
342
+ "tong": ("t", "ong"),
343
+ "tou": ("t", "ou"),
344
+ "tu": ("t", "u"),
345
+ "tuan": ("t", "uan"),
346
+ "tui": ("t", "uei"),
347
+ "tun": ("t", "uen"),
348
+ "tuo": ("t", "uo"),
349
+ "wa": ("^", "ua"),
350
+ "wai": ("^", "uai"),
351
+ "wan": ("^", "uan"),
352
+ "wang": ("^", "uang"),
353
+ "wei": ("^", "uei"),
354
+ "wen": ("^", "uen"),
355
+ "weng": ("^", "ueng"),
356
+ "wo": ("^", "uo"),
357
+ "wu": ("^", "u"),
358
+ "xi": ("x", "i"),
359
+ "xia": ("x", "ia"),
360
+ "xian": ("x", "ian"),
361
+ "xiang": ("x", "iang"),
362
+ "xiao": ("x", "iao"),
363
+ "xie": ("x", "ie"),
364
+ "xin": ("x", "in"),
365
+ "xing": ("x", "ing"),
366
+ "xiong": ("x", "iong"),
367
+ "xiu": ("x", "iou"),
368
+ "xu": ("x", "v"),
369
+ "xuan": ("x", "van"),
370
+ "xue": ("x", "ve"),
371
+ "xun": ("x", "vn"),
372
+ "ya": ("^", "ia"),
373
+ "yan": ("^", "ian"),
374
+ "yang": ("^", "iang"),
375
+ "yao": ("^", "iao"),
376
+ "ye": ("^", "ie"),
377
+ "yi": ("^", "i"),
378
+ "yin": ("^", "in"),
379
+ "ying": ("^", "ing"),
380
+ "yo": ("^", "iou"),
381
+ "yong": ("^", "iong"),
382
+ "you": ("^", "iou"),
383
+ "yu": ("^", "v"),
384
+ "yuan": ("^", "van"),
385
+ "yue": ("^", "ve"),
386
+ "yun": ("^", "vn"),
387
+ "za": ("z", "a"),
388
+ "zai": ("z", "ai"),
389
+ "zan": ("z", "an"),
390
+ "zang": ("z", "ang"),
391
+ "zao": ("z", "ao"),
392
+ "ze": ("z", "e"),
393
+ "zei": ("z", "ei"),
394
+ "zen": ("z", "en"),
395
+ "zeng": ("z", "eng"),
396
+ "zha": ("zh", "a"),
397
+ "zhai": ("zh", "ai"),
398
+ "zhan": ("zh", "an"),
399
+ "zhang": ("zh", "ang"),
400
+ "zhao": ("zh", "ao"),
401
+ "zhe": ("zh", "e"),
402
+ "zhei": ("zh", "ei"),
403
+ "zhen": ("zh", "en"),
404
+ "zheng": ("zh", "eng"),
405
+ "zhi": ("zh", "iii"),
406
+ "zhong": ("zh", "ong"),
407
+ "zhou": ("zh", "ou"),
408
+ "zhu": ("zh", "u"),
409
+ "zhua": ("zh", "ua"),
410
+ "zhuai": ("zh", "uai"),
411
+ "zhuan": ("zh", "uan"),
412
+ "zhuang": ("zh", "uang"),
413
+ "zhui": ("zh", "uei"),
414
+ "zhun": ("zh", "uen"),
415
+ "zhuo": ("zh", "uo"),
416
+ "zi": ("z", "ii"),
417
+ "zong": ("z", "ong"),
418
+ "zou": ("z", "ou"),
419
+ "zu": ("z", "u"),
420
+ "zuan": ("z", "uan"),
421
+ "zui": ("z", "uei"),
422
+ "zun": ("z", "uen"),
423
+ "zuo": ("z", "uo"),
424
+ }
425
+
426
+
427
+ def gen_vocabs():
428
+ import yaml
429
+ vocab = [f"<{c}{i}>" for c in list(pinyin_dict.keys()) for i in range(1,6)]
430
+ yaml.dump(vocab, open('./vocab.yaml', 'w'))
SongBloom/g2p/pinyin/pinyin.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ from pypinyin import Style
4
+ from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin
5
+ from pypinyin.converter import DefaultConverter
6
+ from pypinyin.core import Pinyin
7
+
8
+ from . import pinyin_dict
9
+ import torch
10
+
11
+
12
+ class MyConverter(NeutralToneWith5Mixin, DefaultConverter):
13
+ pass
14
+
15
+
16
+ def is_chinese(uchar):
17
+ if uchar >= u'\u4e00' and uchar <= u'\u9fa5':
18
+ return True
19
+ else:
20
+ return False
21
+
22
+
23
+ def clean_chinese(text: str):
24
+ text = text.strip()
25
+ text_clean = []
26
+ for char in text:
27
+ if (is_chinese(char)):
28
+ text_clean.append(char)
29
+ else:
30
+ if len(text_clean) > 1 and is_chinese(text_clean[-1]):
31
+ text_clean.append(',')
32
+ text_clean = ''.join(text_clean).strip(',')
33
+ return text_clean
34
+
35
+
36
+ class G2P_PinYin():
37
+
38
+ def __init__(self):
39
+ super(G2P_PinYin, self).__init__()
40
+ self.pinyin_parser = Pinyin(MyConverter())
41
+
42
+ def get_phoneme4pinyin(self, pinyins):
43
+ result = []
44
+ count_phone = []
45
+ for pinyin in pinyins:
46
+ if pinyin[:-1] in pinyin_dict:
47
+ tone = pinyin[-1]
48
+ a = pinyin[:-1]
49
+ a1, a2 = pinyin_dict[a]
50
+ result += [a1, a2 + tone]
51
+ count_phone.append(2)
52
+ return result, count_phone
53
+
54
+ # def chinese_to_phonemes(self, text):
55
+ # text = clean_chinese(text)
56
+ # phonemes = ["sil"]
57
+ # chars = ['[PAD]']
58
+ # all_pinyins = []
59
+ # count_phone = []
60
+ # count_phone.append(1)
61
+ # for subtext in text.split(","):
62
+ # if (len(subtext) == 0):
63
+ # continue
64
+ # pinyins = self.correct_pinyin_tone3(subtext)
65
+ # all_pinyins.append(' '.join(pinyins))
66
+ # sub_p, sub_c = self.get_phoneme4pinyin(pinyins)
67
+ # phonemes.extend(sub_p)
68
+ # phonemes.append(",")
69
+ # count_phone.extend(sub_c)
70
+ # count_phone.append(1)
71
+ # chars.append(subtext)
72
+ # chars.append(',')
73
+ # phonemes.append("sil")
74
+ # count_phone.append(1)
75
+ # chars.append('[PAD]')
76
+ # # char_embeds = self.prosody.expand_for_phone(char_embeds, count_phone)
77
+ # return " ".join(phonemes), " ".join(chars), ' , '.join(all_pinyins)
78
+
79
+ def chinese_to_phonemes(self, text):
80
+ all_pinyins = []
81
+ subtext = []
82
+ for chr in text:
83
+ if is_chinese(chr):
84
+ subtext.append(chr)
85
+ else:
86
+ if subtext != []:
87
+ subtext = ''.join(subtext)
88
+ pinyins = self.correct_pinyin_tone3(subtext)
89
+ pinyins = [f"<{i}>" for i in pinyins]
90
+ all_pinyins.append(' '+ ' '.join(pinyins)+ ' ')
91
+ all_pinyins.append(chr)
92
+ subtext = []
93
+ if subtext != []:
94
+ subtext = ''.join(subtext)
95
+ pinyins = self.correct_pinyin_tone3(subtext)
96
+ pinyins = [f"<{i}>" for i in pinyins]
97
+ all_pinyins.append(' '+ ' '.join(pinyins)+ ' ')
98
+ # char_embeds = self.prosody.expand_for_phone(char_embeds, count_phone)
99
+ return ''.join(all_pinyins)
100
+
101
+ def correct_pinyin_tone3(self, text):
102
+ pinyin_list = [
103
+ p[0]
104
+ for p in self.pinyin_parser.pinyin(text,
105
+ style=Style.TONE3,
106
+ strict=False,
107
+ neutral_tone_with_five=True)
108
+ ]
109
+ if len(pinyin_list) >= 2:
110
+ for i in range(1, len(pinyin_list)):
111
+ try:
112
+ if re.findall(r'\d',
113
+ pinyin_list[i - 1])[0] == '3' and re.findall(
114
+ r'\d', pinyin_list[i])[0] == '3':
115
+ pinyin_list[i - 1] = pinyin_list[i - 1].replace(
116
+ '3', '2')
117
+ except IndexError:
118
+ pass
119
+ return pinyin_list
120
+
121
+ # def expand_for_phone(self, char_embeds, length): # length of phones for char
122
+ # if(char_embeds.size(0) > len(length)):
123
+ # print(char_embeds.shape, len(length))
124
+ # char_embeds = char_embeds[0:len(length),:]
125
+ # elif(char_embeds.size(0) < len(length)):
126
+ # print(char_embeds.shape, len(length))
127
+ # length = length[0:char_embeds.size(0)]
128
+ # expand_vecs = list()
129
+ # for vec, leng in zip(char_embeds, length):
130
+ # vec = vec.expand(leng, -1)
131
+ # expand_vecs.append(vec)
132
+ # expand_embeds = torch.cat(expand_vecs, 0)
133
+ # assert expand_embeds.size(0) == sum(length)
134
+ # return expand_embeds.numpy()
135
+
136
+ def __call__(self, text):
137
+ return self.chinese_to_phonemes(text)
SongBloom/g2p/pinyin/symbols.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"]
2
+
3
+ _initials = [
4
+ "^",
5
+ "b",
6
+ "c",
7
+ "ch",
8
+ "d",
9
+ "f",
10
+ "g",
11
+ "h",
12
+ "j",
13
+ "k",
14
+ "l",
15
+ "m",
16
+ "n",
17
+ "p",
18
+ "q",
19
+ "r",
20
+ "s",
21
+ "sh",
22
+ "t",
23
+ "x",
24
+ "z",
25
+ "zh",
26
+ ]
27
+
28
+ _tones = ["1", "2", "3", "4", "5"]
29
+
30
+ _finals = [
31
+ "a",
32
+ "ai",
33
+ "an",
34
+ "ang",
35
+ "ao",
36
+ "e",
37
+ "ei",
38
+ "en",
39
+ "eng",
40
+ "er",
41
+ "i",
42
+ "ia",
43
+ "ian",
44
+ "iang",
45
+ "iao",
46
+ "ie",
47
+ "ii",
48
+ "iii",
49
+ "in",
50
+ "ing",
51
+ "iong",
52
+ "iou",
53
+ "o",
54
+ "ong",
55
+ "ou",
56
+ "u",
57
+ "ua",
58
+ "uai",
59
+ "uan",
60
+ "uang",
61
+ "uei",
62
+ "uen",
63
+ "ueng",
64
+ "uo",
65
+ "v",
66
+ "van",
67
+ "ve",
68
+ "vn",
69
+ ]
70
+
71
+ symbols = _pause + _initials + [i + j for i in _finals for j in _tones]
SongBloom/models/base/sample.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
5
+ """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
6
+
7
+ Args:
8
+ input (torch.Tensor): The input tensor containing probabilities.
9
+ num_samples (int): Number of samples to draw.
10
+ replacement (bool): Whether to draw with replacement or not.
11
+ Keywords args:
12
+ generator (torch.Generator): A pseudorandom number generator for sampling.
13
+ Returns:
14
+ torch.Tensor: Last dimension contains num_samples indices
15
+ sampled from the multinomial probability distribution
16
+ located in the last dimension of tensor input.
17
+ """
18
+ input_ = input.reshape(-1, input.shape[-1])
19
+ output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
20
+ output = output_.reshape(*list(input.shape[:-1]), -1)
21
+ return output
22
+
23
+
24
+ def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
25
+ """Sample next token from top K values along the last dimension of the input probs tensor.
26
+
27
+ Args:
28
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
29
+ k (int): The k in “top-k”.
30
+ Returns:
31
+ torch.Tensor: Sampled tokens.
32
+ """
33
+ top_k_value, _ = torch.topk(probs, k, dim=-1)
34
+ min_value_top_k = top_k_value[..., [-1]]
35
+ probs *= (probs >= min_value_top_k).float()
36
+ probs.div_(probs.sum(dim=-1, keepdim=True))
37
+ next_token = multinomial(probs, num_samples=1)
38
+ return next_token
39
+
40
+
41
+ def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
42
+ """Sample next token from top P probabilities along the last dimension of the input probs tensor.
43
+
44
+ Args:
45
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
46
+ p (int): The p in “top-p”.
47
+ Returns:
48
+ torch.Tensor: Sampled tokens.
49
+ """
50
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
51
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
52
+ mask = probs_sum - probs_sort > p
53
+ probs_sort *= (~mask).float()
54
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
55
+ next_token = multinomial(probs_sort, num_samples=1)
56
+ next_token = torch.gather(probs_idx, -1, next_token)
57
+ return next_token
SongBloom/models/base/utils.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import typing as tp
4
+
5
+ def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
6
+ """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
7
+ For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
8
+
9
+ Args:
10
+ lengths (torch.Tensor): tensor with lengths
11
+ max_len (int): can set the max length manually. Defaults to None.
12
+ Returns:
13
+ torch.Tensor: mask with 0s where there is pad tokens else 1s
14
+ """
15
+ assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
16
+ final_length = lengths.max().item() if not max_len else max_len
17
+ final_length = max(final_length, 1) # if all seqs are of len zero we don't want a zero-size tensor
18
+ return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None]
19
+
20
+
21
+ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
22
+ dtype: torch.dtype = torch.float32) -> torch.Tensor:
23
+ """Create sinusoidal positional embedding, with shape `[B, T, C]`.
24
+
25
+ Args:
26
+ positions (torch.Tensor): LongTensor of positions.
27
+ dim (int): Dimension of the embedding.
28
+ max_period (float): Maximum period of the cosine/sine functions.
29
+ dtype (torch.dtype or str): dtype to use to generate the embedding.
30
+ Returns:
31
+ torch.Tensor: Sinusoidal positional embedding.
32
+ """
33
+ # We aim for BTC format
34
+ assert dim % 2 == 0
35
+ half_dim = dim // 2
36
+ positions = positions.to(dtype)
37
+ adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
38
+ max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) # avoid sync point
39
+ phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
40
+ # phase = phase.to(torch.bfloat16)
41
+ return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
42
+
43
+
44
+ def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
45
+ """Create normalization module for transformer encoder layer.
46
+
47
+ Args:
48
+ norm_type (str): Normalization method.
49
+ dim (int): Dimension of the normalized layer.
50
+ **kwargs (dict): Additional parameters for normalization layer.
51
+ Returns:
52
+ nn.Module: Normalization module.
53
+ """
54
+ if norm_type == 'layer_norm':
55
+ return nn.LayerNorm(dim, eps=1e-5, **kwargs)
56
+ else:
57
+ raise ValueError(f"Unknown norm type: {norm_type}")
SongBloom/models/musicgen/__init__.py ADDED
File without changes
SongBloom/models/musicgen/conditioners/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import omegaconf
8
+ from .base import *
9
+ from .text import *
10
+ from .wav import *
11
+
12
+ KLASS = {
13
+ 'phoneme_tokenizer': PhonemeTokenizerConditioner,
14
+ 'audio_tokenizer_wrapper': AudioTokenizerConditioner,
15
+ }
16
+
17
+ def get_condition_fuser(fuser_cfgs) -> ConditionFuser:
18
+ """Instantiate a condition fuser object."""
19
+ fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate']
20
+ fuse2cond = {k: fuser_cfgs[k] for k in fuser_methods}
21
+ kwargs = {k: v for k, v in fuser_cfgs.items() if k not in fuser_methods}
22
+ fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
23
+ return fuser
24
+
25
+ def get_conditioner_provider(cfg) -> ConditioningProvider:
26
+ """Instantiate a conditioning model."""
27
+
28
+ dict_cfg = {} if cfg is None else dict(cfg)
29
+ conditioners: tp.Dict[str, BaseConditioner] = {}
30
+
31
+ # import pdb; pdb.set_trace()
32
+ for cond, cond_cfg in dict_cfg.items():
33
+ model_args = cond_cfg.copy()
34
+ model_type = model_args.pop('type')
35
+ conditioners[str(cond)] = KLASS[model_type](**model_args)
36
+ conditioner = ConditioningProvider(conditioners)
37
+ return conditioner
SongBloom/models/musicgen/conditioners/base.py ADDED
@@ -0,0 +1,872 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from copy import deepcopy
3
+ from dataclasses import dataclass, field
4
+ from itertools import chain
5
+ import logging
6
+ import typing as tp
7
+ import einops
8
+
9
+ import torch
10
+ from torch import nn
11
+ import torch.nn.functional as F
12
+ from torch.nn.utils.rnn import pad_sequence
13
+
14
+ from dataclasses import dataclass, field, fields, replace
15
+
16
+ from ..modules.streaming import StreamingModule
17
+ from ...base.utils import length_to_mask, create_sin_embedding
18
+
19
+
20
+ def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
21
+ """Get a list of tensors and collate them to a single tensor. according to the following logic:
22
+ - `dim` specifies the time dimension which will be stacked and padded.
23
+ - The output will contain 1 new dimension (dimension index 0) which will be the size of
24
+ of the original list.
25
+
26
+ Args:
27
+ tensors (tp.List[torch.Tensor]): List of tensors to collate.
28
+ dim (int): Dimension which will be stacked and padded.
29
+ Returns:
30
+ tp.Tuple[torch.Tensor, torch.Tensor]:
31
+ torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
32
+ (dimension index 0) which will be the size of the original list.
33
+ torch.Tensor: Tensor containing length of original tensor sizes (without padding).
34
+ """
35
+ tensors = [x.transpose(0, dim) for x in tensors]
36
+ lens = torch.LongTensor([len(x) for x in tensors])
37
+ padded_tensors = pad_sequence(tensors)
38
+ padded_tensors = padded_tensors.transpose(0, 1)
39
+ padded_tensors = padded_tensors.transpose(1, dim + 1)
40
+ return padded_tensors, lens
41
+
42
+
43
+
44
+ @dataclass(order=True)
45
+ class PathInZip:
46
+ """Hold a path of file within a zip file.
47
+
48
+ Args:
49
+ path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
50
+ Let's assume there is a zip file /some/location/foo.zip
51
+ and inside of it is a json file located at /data/file1.json,
52
+ Then we expect path = "/some/location/foo.zip:/data/file1.json".
53
+ """
54
+
55
+ INFO_PATH_SEP = ':'
56
+ zip_path: str
57
+ file_path: str
58
+
59
+ def __init__(self, path: str) -> None:
60
+ split_path = path.split(self.INFO_PATH_SEP)
61
+ assert len(split_path) == 2
62
+ self.zip_path, self.file_path = split_path
63
+
64
+ @classmethod
65
+ def from_paths(cls, zip_path: str, file_path: str):
66
+ return cls(zip_path + cls.INFO_PATH_SEP + file_path)
67
+
68
+ def __str__(self) -> str:
69
+ return self.zip_path + self.INFO_PATH_SEP + self.file_path
70
+
71
+
72
+ @dataclass(order=True)
73
+ class BaseInfo:
74
+
75
+ @classmethod
76
+ def _dict2fields(cls, dictionary: dict):
77
+ return {
78
+ field.name: dictionary[field.name]
79
+ for field in fields(cls) if field.name in dictionary
80
+ }
81
+ # try:
82
+ # return {
83
+ # field.name: dictionary[field.name]
84
+ # for field in fields(cls) if field.name in dictionary
85
+ # }
86
+ # except:
87
+ # print(dictionary)
88
+
89
+ @classmethod
90
+ def from_dict(cls, dictionary: dict):
91
+ _dictionary = cls._dict2fields(dictionary)
92
+ return cls(**_dictionary)
93
+
94
+ def to_dict(self):
95
+ return {
96
+ field.name: self.__getattribute__(field.name)
97
+ for field in fields(self)
98
+ }
99
+
100
+
101
+ @dataclass(order=True)
102
+ class AudioMeta(BaseInfo):
103
+ path: str
104
+ duration: float
105
+ sample_rate: int
106
+ amplitude: tp.Optional[float] = None
107
+ weight: tp.Optional[float] = None
108
+ # info_path is used to load additional information about the audio file that is stored in zip files.
109
+ info_path: tp.Optional[PathInZip] = None
110
+
111
+ @classmethod
112
+ def from_dict(cls, dictionary: dict):
113
+ base = cls._dict2fields(dictionary)
114
+ if 'info_path' in base and base['info_path'] is not None:
115
+ base['info_path'] = PathInZip(base['info_path'])
116
+ return cls(**base)
117
+
118
+ def to_dict(self):
119
+ d = super().to_dict()
120
+ if d['info_path'] is not None:
121
+ d['info_path'] = str(d['info_path'])
122
+ return d
123
+
124
+
125
+ @dataclass(order=True)
126
+ class SegmentInfo(BaseInfo):
127
+ meta: AudioMeta
128
+ seek_time: float
129
+ # The following values are given once the audio is processed, e.g.
130
+ # at the target sample rate and target number of channels.
131
+ n_frames: int # actual number of frames without padding
132
+ total_frames: int # total number of frames, padding included
133
+ sample_rate: int # actual sample rate
134
+ channels: int # number of audio channels.
135
+
136
+
137
+ logger = logging.getLogger(__name__)
138
+ TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist)
139
+ ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
140
+
141
+
142
+ class WavCondition(tp.NamedTuple):
143
+ wav: torch.Tensor
144
+ length: torch.Tensor
145
+ sample_rate: tp.List[int]
146
+ path: tp.List[tp.Optional[str]] = []
147
+ seek_time: tp.List[tp.Optional[float]] = []
148
+
149
+
150
+ class JointEmbedCondition(tp.NamedTuple):
151
+ wav: torch.Tensor
152
+ text: tp.List[tp.Optional[str]]
153
+ length: torch.Tensor
154
+ sample_rate: tp.List[int]
155
+ path: tp.List[tp.Optional[str]] = []
156
+ seek_time: tp.List[tp.Optional[float]] = []
157
+
158
+
159
+ @dataclass
160
+ class ConditioningAttributes:
161
+ text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
162
+ wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
163
+ joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
164
+
165
+ def __getitem__(self, item):
166
+ return getattr(self, item)
167
+
168
+ @property
169
+ def text_attributes(self):
170
+ return self.text.keys()
171
+
172
+ @property
173
+ def wav_attributes(self):
174
+ return self.wav.keys()
175
+
176
+ @property
177
+ def joint_embed_attributes(self):
178
+ return self.joint_embed.keys()
179
+
180
+ @property
181
+ def attributes(self):
182
+ return {
183
+ "text": self.text_attributes,
184
+ "wav": self.wav_attributes,
185
+ "joint_embed": self.joint_embed_attributes,
186
+ }
187
+
188
+ def to_flat_dict(self):
189
+ return {
190
+ **{f"text.{k}": v for k, v in self.text.items()},
191
+ **{f"wav.{k}": v for k, v in self.wav.items()},
192
+ **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
193
+ }
194
+
195
+ @classmethod
196
+ def from_flat_dict(cls, x):
197
+ out = cls()
198
+ for k, v in x.items():
199
+ kind, att = k.split(".")
200
+ out[kind][att] = v
201
+ return out
202
+
203
+
204
+
205
+ # class SegmentWithAttributes(SegmentInfo):
206
+ # """Base class for all dataclasses that are used for conditioning.
207
+ # All child classes should implement `to_condition_attributes` that converts
208
+ # the existing attributes to a dataclass of type ConditioningAttributes.
209
+ # """
210
+ # def to_condition_attributes(self) -> ConditioningAttributes:
211
+ # raise NotImplementedError()
212
+
213
+
214
+
215
+ def nullify_condition(condition: ConditionType, dim: int = 1):
216
+ """Transform an input condition to a null condition.
217
+ The way it is done by converting it to a single zero vector similarly
218
+ to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
219
+
220
+ Args:
221
+ condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
222
+ dim (int): The dimension that will be truncated (should be the time dimension)
223
+ WARNING!: dim should not be the batch dimension!
224
+ Returns:
225
+ ConditionType: A tuple of null condition and mask
226
+ """
227
+ assert dim != 0, "dim cannot be the batch dimension!"
228
+ assert isinstance(condition, tuple) and \
229
+ isinstance(condition[0], torch.Tensor) and \
230
+ isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!"
231
+ cond, mask = condition
232
+ B = cond.shape[0]
233
+ last_dim = cond.dim() - 1
234
+ out = cond.transpose(dim, last_dim)
235
+ out = 0. * out[..., :1]
236
+ out = out.transpose(dim, last_dim)
237
+ mask = torch.zeros((B, 1), device=out.device).int()
238
+ assert cond.dim() == out.dim()
239
+ return out, mask
240
+
241
+
242
+ def nullify_wav(cond: WavCondition) -> WavCondition:
243
+ """Transform a WavCondition to a nullified WavCondition.
244
+ It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes.
245
+
246
+ Args:
247
+ cond (WavCondition): Wav condition with wav, tensor of shape [B, T].
248
+ Returns:
249
+ WavCondition: Nullified wav condition.
250
+ """
251
+ #TODO by YCY, fix this to support zero-length input (as None)
252
+ null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1) # B,1 all-zero
253
+ return WavCondition(
254
+ wav=null_wav,
255
+ length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device),
256
+ sample_rate=cond.sample_rate,
257
+ path=[None] * cond.wav.shape[0],
258
+ seek_time=[None] * cond.wav.shape[0],
259
+ )
260
+
261
+
262
+ def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
263
+ """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0,
264
+ and replacing metadata by dummy attributes.
265
+
266
+ Args:
267
+ cond (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
268
+ """
269
+ null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1)
270
+ return JointEmbedCondition(
271
+ wav=null_wav, text=[None] * len(embed.text),
272
+ length=torch.LongTensor([0]).to(embed.wav.device),
273
+ sample_rate=embed.sample_rate,
274
+ path=[None] * embed.wav.shape[0],
275
+ seek_time=[0] * embed.wav.shape[0],
276
+ )
277
+
278
+
279
+
280
+ class BaseConditioner(nn.Module):
281
+ """Base model for all conditioner modules.
282
+ We allow the output dim to be different than the hidden dim for two reasons:
283
+ 1) keep our LUTs small when the vocab is large;
284
+ 2) make all condition dims consistent.
285
+
286
+ Args:
287
+ dim (int): Hidden dim of the model.
288
+ output_dim (int): Output dim of the conditioner.
289
+ """
290
+ def __init__(self, dim: int, output_dim: int, input_token = False, padding_idx=None):
291
+ super().__init__()
292
+ self.dim = dim
293
+ self.output_dim = output_dim
294
+ if input_token:
295
+ self.output_proj = nn.Embedding(dim, output_dim, padding_idx)
296
+ else:
297
+ self.output_proj = nn.Linear(dim, output_dim)
298
+
299
+ def tokenize(self, *args, **kwargs) -> tp.Any:
300
+ """Should be any part of the processing that will lead to a synchronization
301
+ point, e.g. BPE tokenization with transfer to the GPU.
302
+
303
+ The returned value will be saved and return later when calling forward().
304
+ """
305
+ raise NotImplementedError()
306
+
307
+ def forward(self, inputs: tp.Any) -> ConditionType:
308
+ """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
309
+ Outputs a ConditionType, after the input data was embedded as a dense vector.
310
+
311
+ Returns:
312
+ ConditionType:
313
+ - A tensor of size [B, T, D] where B is the batch size, T is the length of the
314
+ output embedding and D is the dimension of the embedding.
315
+ - And a mask indicating where the padding tokens.
316
+ """
317
+ raise NotImplementedError()
318
+
319
+
320
+
321
+ def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
322
+ """Utility function for nullifying an attribute inside an ConditioningAttributes object.
323
+ If the condition is of type "wav", then nullify it using `nullify_condition` function.
324
+ If the condition is of any other type, set its value to None.
325
+ Works in-place.
326
+ """
327
+ if condition_type not in ['text', 'wav', 'joint_embed']:
328
+ raise ValueError(
329
+ "dropout_condition got an unexpected condition type!"
330
+ f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
331
+ )
332
+
333
+ if condition not in getattr(sample, condition_type):
334
+ raise ValueError(
335
+ "dropout_condition received an unexpected condition!"
336
+ f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
337
+ f" but got '{condition}' of type '{condition_type}'!"
338
+ )
339
+
340
+ if condition_type == 'wav':
341
+ wav_cond = sample.wav[condition]
342
+ sample.wav[condition] = nullify_wav(wav_cond)
343
+ elif condition_type == 'joint_embed':
344
+ embed = sample.joint_embed[condition]
345
+ sample.joint_embed[condition] = nullify_joint_embed(embed)
346
+ else:
347
+ sample.text[condition] = None
348
+
349
+ return sample
350
+
351
+
352
+ class DropoutModule(nn.Module):
353
+ """Base module for all dropout modules."""
354
+ def __init__(self, seed: int = 1234):
355
+ super().__init__()
356
+ self.rng = torch.Generator()
357
+ self.rng.manual_seed(seed)
358
+
359
+
360
+ class AttributeDropout(DropoutModule):
361
+ """Dropout with a given probability per attribute.
362
+ This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
363
+ to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
364
+ This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
365
+ must also be dropped.
366
+
367
+ Args:
368
+ p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
369
+ ...
370
+ "genre": 0.1,
371
+ "artist": 0.5,
372
+ "wav": 0.25,
373
+ ...
374
+ active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
375
+ seed (int, optional): Random seed.
376
+ """
377
+ def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
378
+ super().__init__(seed=seed)
379
+ self.active_on_eval = active_on_eval
380
+ # construct dict that return the values from p otherwise 0
381
+ self.p = {}
382
+ for condition_type, probs in p.items():
383
+ self.p[condition_type] = defaultdict(lambda: 0, probs)
384
+
385
+ def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
386
+ """
387
+ Args:
388
+ samples (list[ConditioningAttributes]): List of conditions.
389
+ Returns:
390
+ list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
391
+ """
392
+ if not self.training and not self.active_on_eval:
393
+ return samples
394
+
395
+ samples = deepcopy(samples)
396
+ for condition_type, ps in self.p.items(): # for condition types [text, wav]
397
+ for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
398
+ # import pdb; pdb.set_trace()
399
+ # print(condition, p)
400
+ if torch.rand(1, generator=self.rng).item() < p:
401
+ for sample in samples:
402
+ dropout_condition(sample, condition_type, condition)
403
+ return samples
404
+
405
+ def __repr__(self):
406
+ return f"AttributeDropout({dict(self.p)})"
407
+
408
+
409
+ class ClassifierFreeGuidanceDropout(DropoutModule):
410
+ """Classifier Free Guidance dropout.
411
+ All attributes are dropped with the same probability.
412
+
413
+ Args:
414
+ p (float): Probability to apply condition dropout during training.
415
+ seed (int): Random seed.
416
+ """
417
+ def __init__(self, p: float, seed: int = 1234):
418
+ super().__init__(seed=seed)
419
+ self.p = p
420
+
421
+ def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
422
+ """
423
+ Args:
424
+ samples (list[ConditioningAttributes]): List of conditions.
425
+ Returns:
426
+ list[ConditioningAttributes]: List of conditions after all attributes were set to None.
427
+ """
428
+
429
+ if not self.training:
430
+ return samples
431
+ # import pdb; pdb.set_trace()
432
+ # decide on which attributes to drop in a batched fashion
433
+ drop = torch.rand(1, generator=self.rng).item() < self.p
434
+ if not drop:
435
+ return samples
436
+
437
+ # nullify conditions of all attributes
438
+ samples = deepcopy(samples)
439
+ for condition_type in ["text", "wav","joint_embed"]:
440
+ for sample in samples:
441
+ for condition in sample.attributes[condition_type]:
442
+ dropout_condition(sample, condition_type, condition)
443
+ return samples
444
+
445
+ def __repr__(self):
446
+ return f"ClassifierFreeGuidanceDropout(p={self.p})"
447
+
448
+
449
+ class TextConditioner(BaseConditioner):
450
+ ...
451
+
452
+
453
+ class WaveformConditioner(BaseConditioner):
454
+ """Base class for all conditioners that take a waveform as input.
455
+ Classes that inherit must implement `_get_wav_embedding` that outputs
456
+ a continuous tensor, and `_downsampling_factor` that returns the down-sampling
457
+ factor of the embedding model.
458
+
459
+ Args:
460
+ dim (int): The internal representation dimension.
461
+ output_dim (int): Output dimension.
462
+ """
463
+ def __init__(self, dim: int, output_dim: int, input_token = False, padding_idx=None):
464
+ super().__init__(dim, output_dim, input_token, padding_idx)
465
+
466
+ def tokenize(self, x: WavCondition) -> WavCondition:
467
+ wav, length, sample_rate, path, seek_time = x
468
+ assert length is not None
469
+ return WavCondition(wav, length, sample_rate, path, seek_time)
470
+
471
+ def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
472
+ """Gets as input a WavCondition and returns a dense embedding."""
473
+ raise NotImplementedError()
474
+
475
+ def _downsampling_factor(self):
476
+ """Returns the downsampling factor of the embedding model."""
477
+ raise NotImplementedError()
478
+
479
+ def forward(self, x: WavCondition) -> ConditionType:
480
+ """Extract condition embedding and mask from a waveform and its metadata.
481
+ Args:
482
+ x (WavCondition): Waveform condition containing raw waveform and metadata.
483
+ Returns:
484
+ ConditionType: a dense vector representing the conditioning along with its mask
485
+ """
486
+
487
+ wav, lengths, *_ = x
488
+ # import pdb; pdb.set_trace()
489
+ with torch.no_grad():
490
+ embeds = self._get_wav_embedding(x)
491
+ embeds = embeds.to(self.output_proj.weight)
492
+ embeds = self.output_proj(embeds)
493
+ # import pdb; pdb.set_trace()
494
+ if lengths is not None:
495
+ lengths = lengths / self._downsampling_factor()
496
+ mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
497
+ else:
498
+ mask = torch.ones_like(embeds)
499
+ embeds = (embeds * mask.unsqueeze(2))
500
+
501
+ return embeds, mask
502
+
503
+
504
+ class JointEmbeddingConditioner(BaseConditioner):
505
+ """Joint embedding conditioning supporting both audio or text conditioning.
506
+
507
+ Args:
508
+ dim (int): Dimension.
509
+ output_dim (int): Output dimension.
510
+ autocast_dtype (str): Autocast for the conditioner.
511
+ quantize (bool): Whether to quantize the CLAP embedding.
512
+ n_q (int): Number of residual quantizers (used if quantize is true).
513
+ bins (int): Quantizers' codebooks size (used if quantize is true).
514
+ kwargs: Additional parameters for residual vector quantizer.
515
+ """
516
+ def __init__(self, dim: int, output_dim: int,
517
+ autocast_dtype: tp.Optional[str] = 'float32', #quantize: bool = False,
518
+ **kwargs):
519
+ super().__init__(dim=dim, output_dim=output_dim)
520
+ self.autocast_dtype = getattr(torch, autocast_dtype) if autocast_dtype is not None \
521
+ else None
522
+ if self.autocast_dtype is None:
523
+ logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
524
+
525
+ # # residual vector quantizer to discretize the conditioned embedding
526
+ # self.quantizer = None
527
+ # if quantize:
528
+ # from ..modules.quantization import ResidualVectorQuantizer
529
+ # self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs)
530
+
531
+ def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
532
+ """Get joint embedding in latent space from the inputs.
533
+
534
+ Returns:
535
+ tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
536
+ and corresponding empty indexes.
537
+ """
538
+ raise NotImplementedError()
539
+
540
+ def forward(self, x: JointEmbedCondition) -> ConditionType:
541
+ with torch.cuda.amp.autocast(dtype=self.autocast_dtype):
542
+ embed, empty_idx = self._get_embed(x)
543
+ if self.quantizer is not None:
544
+ embed = embed.view(-1, self.dim, 1)
545
+ q_res = self.quantizer(embed, frame_rate=1)
546
+ out_embed = q_res.x.view(-1, self.dim)
547
+ else:
548
+ out_embed = embed
549
+ out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim)
550
+ mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
551
+ mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
552
+ out_embed = (out_embed * mask.unsqueeze(-1))
553
+ return out_embed, mask
554
+
555
+ def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
556
+ return x
557
+
558
+
559
+ class ConditioningProvider(nn.Module):
560
+ """Prepare and provide conditions given all the supported conditioners.
561
+
562
+ Args:
563
+ conditioners (dict): Dictionary of conditioners.
564
+ """
565
+ def __init__(self, conditioners: tp.Dict[str, BaseConditioner]):
566
+ super().__init__()
567
+ self.conditioners = nn.ModuleDict(conditioners)
568
+ def _check_conditioner_type(c):
569
+ if isinstance(c, WaveformConditioner):
570
+ return "wav"
571
+ elif isinstance(c, TextConditioner):
572
+ return "text"
573
+ elif isinstance(c, JointEmbeddingConditioner):
574
+ return "joint_embed"
575
+ else:
576
+ raise NotImplementedError(f"{type(c)} are not Implemented!")
577
+ self.conditioner_type = {k: _check_conditioner_type(self.conditioners[k]) for k in self.conditioners}
578
+
579
+
580
+ @property
581
+ def joint_embed_conditions(self):
582
+ return [k for k, v in self.conditioners.items() if isinstance(v, JointEmbeddingConditioner)]
583
+
584
+ @property
585
+ def has_joint_embed_conditions(self):
586
+ return len(self.joint_embed_conditions) > 0
587
+
588
+ @property
589
+ def text_conditions(self):
590
+ return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
591
+
592
+ @property
593
+ def wav_conditions(self):
594
+ return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
595
+
596
+ @property
597
+ def has_wav_condition(self):
598
+ return len(self.wav_conditions) > 0
599
+
600
+ def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
601
+ """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
602
+ This should be called before starting any real GPU work to avoid synchronization points.
603
+ This will return a dict matching conditioner names to their arbitrary tokenized representations.
604
+
605
+ Args:
606
+ inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
607
+ text and wav conditions.
608
+ """
609
+ assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
610
+ "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
611
+ f" but types were {set([type(x) for x in inputs])}"
612
+ )
613
+
614
+ # import pdb; pdb.set_trace()
615
+ output = {}
616
+ text = self._collate_text(inputs)
617
+ wavs = self._collate_wavs(inputs)
618
+ joint_embeds = self._collate_joint_embeds(inputs)
619
+
620
+ assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
621
+ f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
622
+ f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
623
+ )
624
+
625
+ for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
626
+ output[attribute] = self.conditioners[attribute].tokenize(batch)
627
+ return output
628
+
629
+ def forward(self, tokenized: tp.Dict[str, tp.Any], texts = None) -> tp.Dict[str, ConditionType]:
630
+ """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
631
+ The output is for example:
632
+ {
633
+ "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
634
+ "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
635
+ ...
636
+ }
637
+
638
+ Args:
639
+ tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
640
+ """
641
+ # import pdb; pdb.set_trace()
642
+ output = {}
643
+ for attribute, inputs in tokenized.items():
644
+ if attribute == 'self_wav' and texts is not None:
645
+ condition, mask = self.conditioners[attribute](inputs, texts = texts)
646
+ else:
647
+ condition, mask = self.conditioners[attribute](inputs)
648
+ output[attribute] = (condition, mask)
649
+ return output
650
+
651
+ def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
652
+ """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
653
+ are the attributes and the values are the aggregated input per attribute.
654
+ For example:
655
+ Input:
656
+ [
657
+ ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
658
+ ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
659
+ ]
660
+ Output:
661
+ {
662
+ "genre": ["Rock", "Hip-hop"],
663
+ "description": ["A rock song with a guitar solo", "A hip-hop verse"]
664
+ }
665
+
666
+ Args:
667
+ samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
668
+ Returns:
669
+ dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
670
+ """
671
+ out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
672
+ texts = [x.text for x in samples]
673
+ for text in texts:
674
+ for condition in self.text_conditions:
675
+ out[condition].append(text[condition])
676
+ return out
677
+
678
+ def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
679
+ """Generate a dict where the keys are attributes by which we fetch similar wavs,
680
+ and the values are Tensors of wavs according to said attributes.
681
+
682
+ *Note*: by the time the samples reach this function, each sample should have some waveform
683
+ inside the "wav" attribute. It should be either:
684
+ 1. A real waveform
685
+ 2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
686
+ 3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
687
+
688
+ Args:
689
+ samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
690
+ Returns:
691
+ dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
692
+ """
693
+ # import pdb; pdb.set_trace()
694
+ wavs = defaultdict(list)
695
+ lengths = defaultdict(list)
696
+ sample_rates = defaultdict(list)
697
+ paths = defaultdict(list)
698
+ seek_times = defaultdict(list)
699
+ out: tp.Dict[str, WavCondition] = {}
700
+
701
+ for sample in samples:
702
+ for attribute in self.wav_conditions:
703
+ wav, length, sample_rate, path, seek_time = sample.wav[attribute]
704
+ assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
705
+ assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
706
+ # mono-channel conditioning
707
+ # wav = wav.mean(1, keepdim=True) # [1, 1, T] # by cyy, 为了实现后续功能注释掉了,请手动确保channel=1,or 输入channel 符合预期
708
+ wavs[attribute].append(wav.flatten()) # [C*T]
709
+ lengths[attribute].append(length)
710
+ sample_rates[attribute].extend(sample_rate)
711
+ paths[attribute].extend(path)
712
+ seek_times[attribute].extend(seek_time)
713
+
714
+ # stack all wavs to a single tensor
715
+ for attribute in self.wav_conditions:
716
+ stacked_wav, _ = collate(wavs[attribute], dim=0)
717
+ out[attribute] = WavCondition(
718
+ stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
719
+ paths[attribute], seek_times[attribute])
720
+
721
+ return out
722
+
723
+ def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
724
+ """Generate a dict where the keys are attributes by which we compute joint embeddings,
725
+ and the values are Tensors of pre-computed embeddings and the corresponding text attributes.
726
+
727
+ Args:
728
+ samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
729
+ Returns:
730
+ A dictionary mapping an attribute name to joint embeddings.
731
+ """
732
+ texts = defaultdict(list)
733
+ wavs = defaultdict(list)
734
+ lengths = defaultdict(list)
735
+ sample_rates = defaultdict(list)
736
+ paths = defaultdict(list)
737
+ seek_times = defaultdict(list)
738
+ channels: int = 0
739
+
740
+ out = {}
741
+ for sample in samples:
742
+ for attribute in self.joint_embed_conditions:
743
+ wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
744
+ assert wav.dim() == 3
745
+ if channels == 0:
746
+ channels = wav.size(1)
747
+ else:
748
+ assert channels == wav.size(1), "not all audio has same number of channels in batch"
749
+ assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
750
+ wav = einops.rearrange(wav, "b c t -> (b c t)") # [1, C, T] => [C * T]
751
+ wavs[attribute].append(wav)
752
+ texts[attribute].extend(text)
753
+ lengths[attribute].append(length)
754
+ sample_rates[attribute].extend(sample_rate)
755
+ paths[attribute].extend(path)
756
+ seek_times[attribute].extend(seek_time)
757
+
758
+ for attribute in self.joint_embed_conditions:
759
+ stacked_texts = texts[attribute]
760
+ stacked_paths = paths[attribute]
761
+ stacked_seek_times = seek_times[attribute]
762
+ stacked_wavs = pad_sequence(wavs[attribute])
763
+ stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
764
+ stacked_sample_rates = sample_rates[attribute]
765
+ stacked_lengths = torch.cat(lengths[attribute])
766
+
767
+ assert stacked_lengths.size(0) == stacked_wavs.size(0)
768
+ assert len(stacked_sample_rates) == stacked_wavs.size(0)
769
+ assert len(stacked_texts) == stacked_wavs.size(0)
770
+ out[attribute] = JointEmbedCondition(
771
+ text=stacked_texts, wav=stacked_wavs,
772
+ length=stacked_lengths, sample_rate=stacked_sample_rates,
773
+ path=stacked_paths, seek_time=stacked_seek_times)
774
+
775
+ return out
776
+
777
+
778
+ class ConditionFuser(StreamingModule):
779
+ """Condition fuser handles the logic to combine the different conditions
780
+ to the actual model input.
781
+
782
+ Args:
783
+ fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
784
+ each condition. For example:
785
+ {
786
+ "prepend": ["description"],
787
+ "sum": ["genre", "bpm"],
788
+ "cross": ["description"],
789
+ }
790
+ cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
791
+ cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
792
+ """
793
+ FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
794
+
795
+ def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
796
+ cross_attention_pos_emb_scale: float = 1.0):
797
+ super().__init__()
798
+ assert all(
799
+ [k in self.FUSING_METHODS for k in fuse2cond.keys()]
800
+ ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
801
+ self.cross_attention_pos_emb = cross_attention_pos_emb
802
+ self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
803
+ self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
804
+ self.cond2fuse: tp.Dict[str, str] = {}
805
+ for fuse_method, conditions in fuse2cond.items():
806
+ for condition in conditions:
807
+ self.cond2fuse[condition] = fuse_method
808
+
809
+ def forward(
810
+ self,
811
+ input: torch.Tensor,
812
+ conditions: tp.Dict[str, ConditionType]
813
+ ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
814
+ """Fuse the conditions to the provided model input.
815
+
816
+ Args:
817
+ input (torch.Tensor): Transformer input.
818
+ conditions (dict[str, ConditionType]): Dict of conditions.
819
+ Returns:
820
+ tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
821
+ after the conditions have been fused. The second output tensor is the tensor
822
+ used for cross-attention or None if no cross attention inputs exist.
823
+ """
824
+ # import pdb; pdb.set_trace()
825
+ B, T, _ = input.shape
826
+
827
+ if 'offsets' in self._streaming_state:
828
+ first_step = False
829
+ offsets = self._streaming_state['offsets']
830
+ else:
831
+ first_step = True
832
+ offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
833
+
834
+ assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
835
+ f"given conditions contain unknown attributes for fuser, " \
836
+ f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
837
+ cross_attention_output = None
838
+ prepend_input = input[:, :0]
839
+ for cond_type, (cond, cond_mask) in conditions.items():
840
+ op = self.cond2fuse[cond_type]
841
+ if op == 'sum':
842
+ input += cond
843
+ elif op == 'input_interpolate':
844
+ cond = einops.rearrange(cond, "b t d -> b d t")
845
+ cond = F.interpolate(cond, size=input.shape[1])
846
+ input += einops.rearrange(cond, "b d t -> b t d")
847
+ elif op == 'prepend':
848
+ prepend_input = torch.cat([cond.to(input.dtype), prepend_input], dim=1)
849
+ # NOTE 这里cond应该在后,这样顺序才符合配置文件,否则为逆序
850
+ # 但是之前实验是这样的为了保持一致就没改
851
+ elif op == 'cross':
852
+ if cross_attention_output is not None:
853
+ cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
854
+ else:
855
+ cross_attention_output = cond
856
+ else:
857
+ raise ValueError(f"unknown op ({op})")
858
+
859
+ if self.cross_attention_pos_emb and cross_attention_output is not None:
860
+ positions = torch.arange(
861
+ cross_attention_output.shape[1],
862
+ device=cross_attention_output.device
863
+ ).view(1, -1, 1)
864
+ pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
865
+ cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
866
+
867
+ if first_step:
868
+ input = torch.cat([prepend_input, input], dim=1)
869
+ if self._is_streaming:
870
+ self._streaming_state['offsets'] = offsets + T
871
+
872
+ return input, cross_attention_output
SongBloom/models/musicgen/conditioners/text.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base import *
2
+
3
+ import spacy
4
+ import warnings
5
+ import random
6
+ import hashlib
7
+ from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer, AutoTokenizer, XLMRobertaModel, XLMRobertaTokenizer # type: ignore
8
+ from num2words import num2words
9
+
10
+ def hash_trick(word: str, vocab_size: int) -> int:
11
+ """Hash trick to pair each word with an index
12
+
13
+ Args:
14
+ word (str): word we wish to convert to an index
15
+ vocab_size (int): size of the vocabulary
16
+ Returns:
17
+ int: index of the word in the embedding LUT
18
+ """
19
+
20
+ hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16)
21
+ return hash % vocab_size
22
+
23
+
24
+
25
+ class PhonemeTokenizerConditioner(TextConditioner):
26
+ def __init__(self,
27
+ output_dim: int,
28
+ vocab_list,
29
+ max_len = 600,
30
+ max_sentence_per_structure = 50,
31
+ structure_tokens=None,
32
+ structure_split_tokens=[','],
33
+ sentence_split_tokens=['.'],
34
+ mode='sum',
35
+ structure_output_dim = 64,
36
+ sentence_output_dim = 64,
37
+ max_duration = 120,
38
+ interpolate = False,
39
+ ):
40
+
41
+ self.vocab_list = vocab_list
42
+ self.max_len = max_len
43
+ self.mode = mode
44
+ self.max_sentence_per_structure = max_sentence_per_structure
45
+ voc_size = len(self.vocab_list)
46
+ self.interpolate = interpolate
47
+
48
+ if structure_tokens is None:
49
+ structure_tokens = [i for i in vocab_list if len(i) > 1 and i[0] == '[' and i[-1] == ']']
50
+ self.structure_token_ids = [vocab_list.index(i) for i in structure_tokens if i in vocab_list]
51
+ self.structure_split_token_ids = [vocab_list.index(i) for i in structure_split_tokens]
52
+ self.sentence_split_token_ids = [vocab_list.index(i) for i in sentence_split_tokens]
53
+
54
+ # here initialize a output_proj (nn.Embedding) layer
55
+ # By default the first vocab is "" (null)
56
+ if mode == 'sum':
57
+ content_output_dim = output_dim
58
+ sentence_output_dim = output_dim
59
+ structure_output_dim = output_dim
60
+ else: # concat
61
+ content_output_dim = output_dim - sentence_output_dim - structure_output_dim # by default
62
+
63
+ super().__init__(voc_size, content_output_dim, input_token=True, padding_idx=0)
64
+ if self.mode != 'sum':
65
+ self.special_emb = nn.Embedding(len(self.structure_token_ids)+len(self.structure_split_token_ids)+len(self.sentence_split_token_ids)+1,
66
+ structure_output_dim, padding_idx=0)
67
+
68
+ self.blank_emb = nn.Parameter(torch.zeros(1, output_dim), requires_grad=False)
69
+
70
+ # the first index is "empty structure" token
71
+ self.sentence_idx_in_structure_emb = nn.Embedding(max_sentence_per_structure, sentence_output_dim, padding_idx=0)
72
+
73
+ # print("max_len", self.max_len)
74
+ print(self.structure_token_ids)
75
+
76
+ self.resolution = max_duration / max_len # e.g., 120 / 600 = 0.2s
77
+ print(self.__class__, f"resolution = {self.resolution}")
78
+
79
+ def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
80
+ inputs = []
81
+ for xx in x:
82
+ xx = '' if xx is None else xx
83
+ vocab_id = [self.vocab_list.index(item) for item in xx.split(" ") if item in self.vocab_list]
84
+ inputs.append(torch.tensor(vocab_id).long()) # [T]
85
+ return inputs
86
+
87
+
88
+ def interpolate_with_structure_duration(self, special_tokens, embeds, structure_dur):
89
+ # embeds: [T, N]
90
+ def sec2idx(sec): # convert duration sec to token index
91
+ return int(sec / self.resolution)
92
+
93
+ def target_token_types2list(tokens, target_token_types):
94
+
95
+ is_target_list = torch.any(torch.stack([tokens == i for i in target_token_types], dim=-1), dim=-1)
96
+ is_target_list = torch.where(is_target_list)[0].tolist()
97
+ return is_target_list
98
+
99
+ structure_ids = []
100
+ for (structure, st, et) in structure_dur:
101
+ structure_ids.append([structure, sec2idx(st), sec2idx(et)])
102
+
103
+ """
104
+ interpolate embeddings of each structure according to its duration
105
+ """
106
+ is_structure_list = target_token_types2list(special_tokens, self.structure_token_ids)
107
+ is_structure_list.append(special_tokens.shape[-1])
108
+
109
+ split_tokens = deepcopy(self.structure_split_token_ids)
110
+ split_tokens.extend(self.sentence_split_token_ids)
111
+ # is_split_list = target_token_types2list(special_tokens, split_tokens)
112
+
113
+
114
+ interpolated_embeds = embeds[:is_structure_list[0]]
115
+ for i, st in enumerate(is_structure_list[:-1]):
116
+ # (lorry) Explain "-tmp":
117
+ # All structures are connected with " , " token,
118
+ # " ," is also the final token of each structure except the final one,
119
+ # but here we dont want to interpolate " , " token
120
+ tmp = 1
121
+ if i == len(is_structure_list[:-1]) - 1: # the final structure, no need for "-1"
122
+ tmp = 0
123
+
124
+ # print(st, is_structure_list[i+1]-tmp)
125
+ to_interpolate = embeds[st: is_structure_list[i+1] - tmp]
126
+ interpolate_size = structure_ids[i][2] - structure_ids[i][1] - tmp
127
+ # print(interpolate_size)
128
+
129
+ #import pdb; pdb.set_trace()
130
+ # print(interpolated_embeds.shape, to_interpolate.shape, interpolate_size, )
131
+ if to_interpolate.shape[0] == 0:
132
+ import pdb; pdb.set_trace()
133
+ this_interpolated_embeds = F.interpolate(to_interpolate.unsqueeze(0).transpose(2, 1),
134
+ size=interpolate_size,
135
+ mode='nearest-exact').squeeze(0).transpose(1, 0)
136
+
137
+ if tmp == 1:
138
+ interpolated_embeds = torch.cat((interpolated_embeds, this_interpolated_embeds,
139
+ embeds[is_structure_list[i+1]].unsqueeze(0)), 0)
140
+ else:
141
+ interpolated_embeds = torch.cat((interpolated_embeds, this_interpolated_embeds), 0)
142
+ return interpolated_embeds
143
+
144
+
145
+ def forward(self, batch_tokens: tp.List, structure_dur = None) -> ConditionType:
146
+ """
147
+ Encode token_id into three types of embeddings:
148
+ 1) content embedding: phoneme only (or meaningful contents to be sung out)
149
+ 2) structure embedding: structure / separation embeddings, including structures (verse/chorus/...), separators (. / ,)
150
+ The two above share the same embedding layer, can be changed to separate embedding layers.
151
+ 3) sentence_idx embedding (per structure):
152
+ """
153
+ embeds_batch = []
154
+ # print(batch_tokens)
155
+ for b in range(len(batch_tokens)):
156
+ tokens = batch_tokens[b]
157
+
158
+ content_tokens = torch.zeros_like(tokens)
159
+ special_tokens = torch.zeros_like(tokens)
160
+ sentence_idx_in_structure_tokens = torch.zeros_like(tokens)
161
+
162
+ current_structure_idx = 1
163
+ current_sentence_in_structure_idx = 1
164
+ current_structure = 0
165
+
166
+ for i in range(tokens.shape[-1]):
167
+ token = tokens[i]
168
+ if token in self.structure_token_ids: # structure token
169
+ # only update structure token, leave content and sentence index token null (default 0)
170
+ if self.mode == 'sum':
171
+ special_tokens[i] = token
172
+ else:
173
+ special_tokens[i] = self.structure_token_ids.index(token) + 1
174
+ current_structure = token
175
+ current_structure_idx += 1
176
+ current_sentence_in_structure_idx = 1
177
+
178
+ elif token in self.sentence_split_token_ids: # utterance split token
179
+ # only update structure token, leave content and sentence index token null (default 0)
180
+ # add up sentence index
181
+ if self.mode == 'sum':
182
+ special_tokens[i] = token
183
+ else:
184
+ special_tokens[i] = self.sentence_split_token_ids.index(token) + 1 + len(self.structure_token_ids)
185
+ current_sentence_in_structure_idx += 1
186
+
187
+ elif token in self.structure_split_token_ids: # structure split token
188
+ # update structure token (current structure), content token (current token),
189
+ # blank index token
190
+ if self.mode == 'sum':
191
+ special_tokens[i] = token
192
+ else:
193
+ special_tokens[i] = self.structure_split_token_ids.index(token) + 1 + len(self.structure_token_ids) + len(self.sentence_split_token_ids)
194
+
195
+ else: # content tokens
196
+ content_tokens[i] = token
197
+ special_tokens[i] = current_structure
198
+ sentence_idx_in_structure_tokens[i] = min(current_sentence_in_structure_idx, self.max_sentence_per_structure - 1)
199
+
200
+ # print("tokens", tokens.max(), tokens.min())
201
+ # print("special tokens", special_tokens.max(), special_tokens.min())
202
+ # print("sentence idx in structure", sentence_idx_in_structure_tokens.max(), sentence_idx_in_structure_tokens.min())
203
+ device = self.output_proj.weight.device
204
+
205
+ # import pdb; pdb.set_trace()
206
+ content_embeds = self.output_proj(tokens.to(device)) # [T, N]
207
+ if self.mode == 'sum':
208
+ structure_embeds = self.output_proj(special_tokens.to(device))
209
+ else:
210
+ structure_embeds = self.special_emb(special_tokens.to(device))
211
+ sentence_idx_embeds = self.sentence_idx_in_structure_emb(sentence_idx_in_structure_tokens.to(device))
212
+
213
+ if self.mode == 'sum':
214
+ embeds = content_embeds + structure_embeds + sentence_idx_embeds
215
+ else:
216
+ embeds = torch.cat((content_embeds, structure_embeds, sentence_idx_embeds), -1) # [T, N]
217
+
218
+ if self.interpolate:
219
+ embeds = self.interpolate_with_structure_duration(tokens, embeds, structure_dur[b])
220
+ embeds_batch.append(embeds)
221
+
222
+ # set batch_size = 1, [B, T, N]
223
+ if self.max_len is not None:
224
+ max_len = self.max_len
225
+ else:
226
+ max_len = max([e.shape[0] for e in embeds_batch])
227
+ embeds, mask = self.pad_2d_tensor(embeds_batch, max_len)
228
+
229
+ return embeds, mask
230
+
231
+
232
+ def pad_2d_tensor(self, xs, max_len):
233
+ new_tensor = []
234
+ new_mask = []
235
+ for x in xs:
236
+ seq_len, dim = x.size()
237
+ pad_len = max_len - seq_len
238
+
239
+ if pad_len > 0:
240
+ pad_tensor = self.blank_emb.repeat(pad_len, 1).to(x.device) # T, D
241
+ padded_tensor = torch.cat([x, pad_tensor], dim=0)
242
+ mask = torch.cat((torch.ones_like(x[:, 0]),
243
+ torch.zeros_like(pad_tensor[:, 0])), 0) # T
244
+ elif pad_len < 0:
245
+ padded_tensor = x[:max_len]
246
+ mask = torch.ones_like(padded_tensor[:, 0])
247
+ else:
248
+ padded_tensor = x
249
+ mask = torch.ones_like(x[:, 0])
250
+
251
+ new_tensor.append(padded_tensor)
252
+ new_mask.append(mask)
253
+ # [B, T, D] & [B, T]
254
+ return torch.stack(new_tensor, 0), torch.stack(new_mask, 0)
SongBloom/models/musicgen/conditioners/wav.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from .base import *
3
+ import omegaconf
4
+ from ...vae_frontend import AbstractVAE
5
+
6
+ def pad_to_fix_length(x, max_len, pad_value=0.):
7
+ bsz, seq_len = x.shape[:2]
8
+ if seq_len >= max_len:
9
+ return x[:, :max_len]
10
+ else:
11
+ pad_len = max_len - seq_len
12
+ pad_tensor = torch.full((bsz, pad_len, *x.shape[2:]), pad_value, dtype=x.dtype, device=x.device)
13
+ padded_tensor = torch.cat([x, pad_tensor], dim=1)
14
+ return padded_tensor
15
+
16
+ class AudioTokenizerConditioner(WaveformConditioner):
17
+ def __init__(self, output_dim, audio_tokenizer, cache=False, max_len=None):
18
+ super().__init__(output_dim, output_dim)
19
+ self.max_len = max_len
20
+ self.use_cache = cache
21
+
22
+ self.tokenizer = audio_tokenizer
23
+ # breakpoint()
24
+
25
+ # TODO if cached and not load vae, receive a dict instead
26
+ if isinstance(self.tokenizer, dict):
27
+ self.tokenizer = omegaconf.DictConfig(self.tokenizer)
28
+ self.code_depth = self.tokenizer.channel_dim
29
+
30
+
31
+ elif isinstance(self.tokenizer, AbstractVAE):
32
+ self.tokenizer_tp = "vae"
33
+ if self.use_cache:
34
+ self.code_depth = self.tokenizer.channel_dim
35
+ else:
36
+ self.code_depth = 1 # TODO 强制把输入channel设成1了 self.tokenizer.input_channel
37
+ self.output_proj = nn.Identity() if self.output_dim == self.tokenizer.channel_dim \
38
+ else nn.Linear(self.tokenizer.channel_dim, self.output_dim, bias=False)
39
+
40
+ else:
41
+ raise NotImplementedError
42
+
43
+
44
+ def forward(self, x: WavCondition):
45
+ wav, lengths, *_ = x
46
+ B = wav.shape[0]
47
+ wav = wav.reshape(B, self.code_depth, -1)
48
+ # print(wav.shape)
49
+ # import torchaudio
50
+ # torchaudio.save("/apdcephfs_cq7/share_1297902/common/erichtchen/shixisheng/cyy/project/music_generation_repo/core/models/musicgen/conditioners/111.wav", wav[0].cpu(), 48000)
51
+ if self.tokenizer_tp == "vae":
52
+ if self.use_cache:
53
+ audio_latents = wav.transpose(-1,-2)
54
+ else:
55
+ with torch.no_grad():
56
+ audio_latents = self.tokenizer.encode(wav).transpose(-1,-2)
57
+ # print('transform wav to vae')
58
+ audio_latents = self.output_proj(audio_latents)
59
+
60
+ # print(audio_latents.shape)
61
+ if self.max_len is not None:
62
+ audio_latents = pad_to_fix_length(audio_latents, self.max_len, 0.)
63
+
64
+ if lengths is not None:
65
+ lengths = torch.round(lengths.float() * audio_latents.shape[1] / wav.shape[-1])
66
+ mask = length_to_mask(lengths, max_len=audio_latents.shape[1]).int() # type: ignore
67
+ else:
68
+ mask = torch.ones((B, audio_latents.shape[1]), device=audio_latents.device,dtype=torch.int)
69
+
70
+ audio_latents = audio_latents * mask[..., None]
71
+
72
+ return audio_latents, mask
73
+
74
+
SongBloom/models/musicgen/get_backend.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os,sys
3
+ from transformers.utils import is_flash_attn_2_available
4
+ from transformers.models.llama import LlamaModel, LlamaConfig
5
+ from transformers.models.bart.modeling_bart import BartEncoder, BartDecoder, BartConfig
6
+ import warnings
7
+ # from transformers.models.musicgen.modeling_musicgen import MusicgenModel, MusicgenDecoder, MusicgenDecoderConfig # 用的就是BartDecoder,但是没有cross-attn
8
+
9
+ try:
10
+ assert is_flash_attn_2_available()
11
+ assert torch.cuda.get_device_capability(torch.device("cuda")) >= (8, 0)
12
+ assert os.environ.get("DISABLE_FLASH_ATTN",'0') != "1"
13
+ _enable_flash_attention = True
14
+ except:
15
+ _enable_flash_attention = False
16
+
17
+ if not _enable_flash_attention:
18
+ warnings.warn("Not support flash-attn!")
19
+
20
+ def get_backend(name, dim, num_heads, num_layers, hidden_scale, init_std=0.02, rope_theta=10000,):
21
+ # SA (causal) - FF
22
+ if name == 'llama':
23
+ model_cfg = LlamaConfig(
24
+ hidden_size=dim,
25
+ intermediate_size=dim * hidden_scale,
26
+ num_attention_heads=num_heads,
27
+ num_hidden_layers=num_layers,
28
+ num_key_value_heads=num_heads,
29
+ vocab_size=dim,
30
+ use_cache=False,
31
+ max_position_embeddings=4096,
32
+ hidden_act="silu",
33
+ initializer_range=init_std,
34
+ rope_theta=rope_theta,
35
+ _attn_implementation="flash_attention_2" if _enable_flash_attention else "eager",
36
+ )
37
+ model = LlamaModel(model_cfg)
38
+
39
+ # SA -FF
40
+ elif name == 'bart_enc':
41
+ model_cfg = BartConfig(
42
+ d_model=dim,
43
+ max_position_embeddings=4096,
44
+ dropout=0.,
45
+ use_cache=False,
46
+ _attn_implementation="flash_attention_2" if _enable_flash_attention else "eager",
47
+ activation_function='gelu',
48
+ # for BartEncoder
49
+ encoder_layers=num_layers,
50
+ encoder_attention_heads=num_heads,
51
+ init_std=init_std,
52
+ encoder_ffn_dim=dim * hidden_scale,
53
+ )
54
+ model = BartEncoder(model_cfg)
55
+
56
+ # SA - CA - FF
57
+ elif name == 'bart_dec':
58
+ model_cfg = BartConfig(
59
+ d_model=dim,
60
+ max_position_embeddings=4096,
61
+ dropout=0.,
62
+ use_cache=False,
63
+ _attn_implementation="flash_attention_2" if _enable_flash_attention else "eager",
64
+ activation_function='gelu',
65
+ # for BartDecoder
66
+ decoder_layers=num_layers,
67
+ decoder_attention_heads=num_heads,
68
+ decoder_ffn_dim=dim * hidden_scale,
69
+ )
70
+ model = BartDecoder(model_cfg)
71
+
72
+ else:
73
+ raise NotImplementedError
74
+
75
+ delattr(model, "embed_tokens")
76
+ return model
SongBloom/models/musicgen/modules/streaming.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Streaming module API that should be implemented by all Streaming components,
3
+ """
4
+
5
+ from contextlib import contextmanager
6
+ import typing as tp
7
+ from torch import nn
8
+ import torch
9
+
10
+
11
+ State = tp.Dict[str, torch.Tensor]
12
+
13
+
14
+ class StreamingModule(nn.Module):
15
+ """Common API for streaming components.
16
+
17
+ Each streaming component has a streaming state, which is just a dict[str, Tensor].
18
+ By convention, the first dim of each tensor must be the batch size.
19
+ Don't use dots in the key names, as this would clash with submodules
20
+ (like in state_dict).
21
+
22
+ If `self._is_streaming` is True, the component should use and remember
23
+ the proper state inside `self._streaming_state`.
24
+
25
+ To set a streaming component in streaming state, use
26
+
27
+ with module.streaming():
28
+ ...
29
+
30
+ This will automatically reset the streaming state when exiting the context manager.
31
+ This also automatically propagates to all streaming children module.
32
+
33
+ Some module might also implement the `StreamingModule.flush` method, although
34
+ this one is trickier, as all parents module must be StreamingModule and implement
35
+ it as well for it to work properly. See `StreamingSequential` after.
36
+ """
37
+ def __init__(self) -> None:
38
+ super().__init__()
39
+ self._streaming_state: State = {}
40
+ self._is_streaming = False
41
+
42
+ def _apply_named_streaming(self, fn: tp.Any):
43
+ for name, module in self.named_modules():
44
+ if isinstance(module, StreamingModule):
45
+ fn(name, module)
46
+
47
+ def _set_streaming(self, streaming: bool):
48
+ def _set_streaming(name, module):
49
+ module._is_streaming = streaming
50
+ self._apply_named_streaming(_set_streaming)
51
+
52
+ @contextmanager
53
+ def streaming(self):
54
+ """Context manager to enter streaming mode. Reset streaming state on exit."""
55
+ self._set_streaming(True)
56
+ try:
57
+ yield
58
+ finally:
59
+ self._set_streaming(False)
60
+ self.reset_streaming()
61
+
62
+ def reset_streaming(self):
63
+ """Reset the streaming state."""
64
+ def _reset(name: str, module: StreamingModule):
65
+ module._streaming_state.clear()
66
+
67
+ self._apply_named_streaming(_reset)
68
+
69
+ def get_streaming_state(self) -> State:
70
+ """Return the streaming state, including that of sub-modules."""
71
+ state: State = {}
72
+
73
+ def _add(name: str, module: StreamingModule):
74
+ if name:
75
+ name += "."
76
+ for key, value in module._streaming_state.items():
77
+ state[name + key] = value
78
+
79
+ self._apply_named_streaming(_add)
80
+ return state
81
+
82
+ def set_streaming_state(self, state: State):
83
+ """Set the streaming state, including that of sub-modules."""
84
+ state = dict(state)
85
+
86
+ def _set(name: str, module: StreamingModule):
87
+ if name:
88
+ name += "."
89
+ module._streaming_state.clear()
90
+ for key, value in list(state.items()):
91
+ # complexity is not ideal here, but probably fine.
92
+ if key.startswith(name):
93
+ local_key = key[len(name):]
94
+ if '.' not in local_key:
95
+ module._streaming_state[local_key] = value
96
+ del state[key]
97
+
98
+ self._apply_named_streaming(_set)
99
+ assert len(state) == 0, list(state.keys())
100
+
101
+ def flush(self, x: tp.Optional[torch.Tensor] = None):
102
+ """Flush any remaining outputs that were waiting for completion.
103
+ Typically, for convolutions, this will add the final padding
104
+ and process the last buffer.
105
+
106
+ This should take an optional argument `x`, which will be provided
107
+ if a module before this one in the streaming pipeline has already
108
+ spitted out a flushed out buffer.
109
+ """
110
+ if x is None:
111
+ return None
112
+ else:
113
+ return self(x)
114
+
115
+
116
+ class StreamingSequential(StreamingModule, nn.Sequential):
117
+ """A streaming compatible alternative of `nn.Sequential`.
118
+ """
119
+ def flush(self, x: tp.Optional[torch.Tensor] = None):
120
+ for module in self:
121
+ if isinstance(module, StreamingModule):
122
+ x = module.flush(x)
123
+ elif x is not None:
124
+ x = module(x)
125
+ return x
SongBloom/models/musicldm/__init__.py ADDED
File without changes
SongBloom/models/musicldm/inference/__init__.py ADDED
File without changes
SongBloom/models/musicldm/inference/sampling.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ from tqdm import trange, tqdm
4
+
5
+ # import k_diffusion as K
6
+
7
+ # Define the noise schedule and sampling loop
8
+ def get_alphas_sigmas(t):
9
+ """Returns the scaling factors for the clean image (alpha) and for the
10
+ noise (sigma), given a timestep."""
11
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
12
+
13
+ def alpha_sigma_to_t(alpha, sigma):
14
+ """Returns a timestep, given the scaling factors for the clean image and for
15
+ the noise."""
16
+ return torch.atan2(sigma, alpha) / math.pi * 2
17
+
18
+ def t_to_alpha_sigma(t):
19
+ """Returns the scaling factors for the clean image and for the noise, given
20
+ a timestep."""
21
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
22
+
23
+
24
+ @torch.no_grad()
25
+ def sample_discrete_euler(model, x, steps, sigma_max=1.0, prog_bar=False, **extra_args):
26
+ """Draws samples from a model given starting noise. Euler method"""
27
+
28
+ # Make tensor of ones to broadcast the single t values
29
+ ts = x.new_ones([x.shape[0]])
30
+
31
+ # Create the noise schedule
32
+ t = torch.linspace(sigma_max, 0, steps + 1)
33
+ # all = {}
34
+
35
+ #alphas, sigmas = 1-t, t
36
+ iterator = tqdm(zip(t[:-1], t[1:]), total=steps) if prog_bar else zip(t[:-1], t[1:])
37
+ for t_curr, t_prev in iterator:
38
+ # Broadcast the current timestep to the correct shape
39
+ t_curr_tensor = t_curr * torch.ones(
40
+ (x.shape[0],), dtype=x.dtype, device=x.device
41
+ )
42
+ dt = t_prev - t_curr # we solve backwards in our formulation
43
+ v = model(x, t_curr_tensor, **extra_args)
44
+ # all[t_curr.item()] = x-t_curr*v
45
+ x = x + dt * v #.denoise(x, denoiser, t_curr_tensor, cond, uc)
46
+
47
+ # If we are on the last timestep, output the denoised image
48
+ return x #, all
49
+
50
+ @torch.no_grad()
51
+ def sample_discrete_euler_with_temperature(model, x, steps, temperature=1.0, sigma_max=1.0, prog_bar=False, **extra_args):
52
+ """Draws samples from a model given starting noise. Euler method"""
53
+
54
+ # Make tensor of ones to broadcast the single t values
55
+ ts = x.new_ones([x.shape[0]])
56
+ noise = x
57
+
58
+ # Create the noise schedule
59
+ t = torch.linspace(sigma_max, 0, steps + 1)
60
+ # all = {}
61
+ x = torch.zeros_like(noise)
62
+ if temperature >= sigma_max:
63
+ x = noise
64
+
65
+ #alphas, sigmas = 1-t, t
66
+ iterator = tqdm(zip(t[:-1], t[1:]), total=steps) if prog_bar else zip(t[:-1], t[1:])
67
+ for t_curr, t_prev in iterator:
68
+ # Broadcast the current timestep to the correct shape
69
+
70
+ t_curr_tensor = t_curr * torch.ones(
71
+ (x.shape[0],), dtype=x.dtype, device=x.device
72
+ )
73
+ dt = t_prev - t_curr # we solve backwards in our formulation
74
+ v = model(x, t_curr_tensor, **extra_args)
75
+ # all[t_curr.item()] = x-t_curr*v
76
+ if t_curr > temperature and t_prev <= temperature:
77
+ x_0 = x - v
78
+ x = (1-t_prev) * x_0 + t_prev * noise
79
+ else:
80
+ x = x + dt * v #.denoise(x, denoiser, t_curr_tensor, cond, uc)
81
+
82
+ # If we are on the last timestep, output the denoised image
83
+ return x #, all
84
+
85
+
86
+ @torch.no_grad()
87
+ def sample(model, x, steps, eta, prog_bar=False, **extra_args):
88
+ """Draws samples from a model given starting noise. v-diffusion"""
89
+ ts = x.new_ones([x.shape[0]])
90
+ origin_dtype = x.dtype
91
+ # Create the noise schedule
92
+ t = torch.linspace(1, 0, steps + 1)[:-1]
93
+
94
+ alphas, sigmas = get_alphas_sigmas(t)
95
+
96
+ # The sampling loop
97
+ bar = trange if prog_bar else range
98
+ for i in bar(steps):
99
+
100
+ # Get the model output (v, the predicted velocity)
101
+ with torch.cuda.amp.autocast():
102
+ v = model(x, ts * t[i], **extra_args).float()
103
+
104
+ # Predict the noise and the denoised image
105
+ pred = x * alphas[i] - v * sigmas[i]
106
+ eps = x * sigmas[i] + v * alphas[i]
107
+
108
+ # If we are not on the last timestep, compute the noisy image for the
109
+ # next timestep.
110
+ if i < steps - 1:
111
+ # If eta > 0, adjust the scaling factor for the predicted noise
112
+ # downward according to the amount of additional noise to add
113
+ ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
114
+ (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
115
+ adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()
116
+
117
+ # Recombine the predicted noise and predicted denoised image in the
118
+ # correct proportions for the next step
119
+ x = pred * alphas[i + 1] + eps * adjusted_sigma
120
+ # Add the correct amount of fresh noise
121
+ if eta:
122
+ x += torch.randn_like(x) * ddim_sigma
123
+
124
+ # If we are on the last timestep, output the denoised image
125
+ return pred.to(origin_dtype)
126
+
127
+ # Soft mask inpainting is just shrinking hard (binary) mask inpainting
128
+ # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
129
+ def get_bmask(i, steps, mask):
130
+ strength = (i+1)/(steps)
131
+ # convert to binary mask
132
+ bmask = torch.where(mask<=strength,1,0)
133
+ return bmask
134
+
135
+ def make_cond_model_fn(model, cond_fn):
136
+ def cond_model_fn(x, sigma, **kwargs):
137
+ with torch.enable_grad():
138
+ x = x.detach().requires_grad_()
139
+ denoised = model(x, sigma, **kwargs)
140
+ cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
141
+ cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
142
+ return cond_denoised
143
+ return cond_model_fn
144
+
145
+ # Uses k-diffusion from https://github.com/crowsonkb/k-diffusion
146
+ # init_data is init_audio as latents (if this is latent diffusion)
147
+ # For sampling, set both init_data and mask to None
148
+ # For variations, set init_data
149
+ # For inpainting, set both init_data & mask
150
+ def sample_k(
151
+ model_fn,
152
+ noise,
153
+ init_data=None,
154
+ mask=None,
155
+ steps=100,
156
+ sampler_type="dpmpp-2m-sde",
157
+ sigma_min=0.5,
158
+ sigma_max=50,
159
+ rho=1.0, device="cuda",
160
+ callback=None,
161
+ cond_fn=None,
162
+ **extra_args
163
+ ):
164
+
165
+ denoiser = K.external.VDenoiser(model_fn)
166
+
167
+ if cond_fn is not None:
168
+ denoiser = make_cond_model_fn(denoiser, cond_fn)
169
+
170
+ # Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has
171
+ sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
172
+ # Scale the initial noise by sigma
173
+ noise = noise * sigmas[0]
174
+
175
+ wrapped_callback = callback
176
+
177
+ if mask is None and init_data is not None:
178
+ # VARIATION (no inpainting)
179
+ # set the initial latent to the init_data, and noise it with initial sigma
180
+ x = init_data + noise
181
+ elif mask is not None and init_data is not None:
182
+ # INPAINTING
183
+ bmask = get_bmask(0, steps, mask)
184
+ # initial noising
185
+ input_noised = init_data + noise
186
+ # set the initial latent to a mix of init_data and noise, based on step 0's binary mask
187
+ x = input_noised * bmask + noise * (1-bmask)
188
+ # define the inpainting callback function (Note: side effects, it mutates x)
189
+ # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105
190
+ # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
191
+ # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)`
192
+ def inpainting_callback(args):
193
+ i = args["i"]
194
+ x = args["x"]
195
+ sigma = args["sigma"]
196
+ #denoised = args["denoised"]
197
+ # noise the init_data input with this step's appropriate amount of noise
198
+ input_noised = init_data + torch.randn_like(init_data) * sigma
199
+ # shrinking hard mask
200
+ bmask = get_bmask(i, steps, mask)
201
+ # mix input_noise with x, using binary mask
202
+ new_x = input_noised * bmask + x * (1-bmask)
203
+ # mutate x
204
+ x[:,:,:] = new_x[:,:,:]
205
+ # wrap together the inpainting callback and the user-submitted callback.
206
+ if callback is None:
207
+ wrapped_callback = inpainting_callback
208
+ else:
209
+ wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
210
+ else:
211
+ # SAMPLING
212
+ # set the initial latent to noise
213
+ x = noise
214
+
215
+
216
+ with torch.cuda.amp.autocast():
217
+ if sampler_type == "k-heun":
218
+ return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
219
+ elif sampler_type == "k-lms":
220
+ return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
221
+ elif sampler_type == "k-dpmpp-2s-ancestral":
222
+ return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
223
+ elif sampler_type == "k-dpm-2":
224
+ return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
225
+ elif sampler_type == "k-dpm-fast":
226
+ return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
227
+ elif sampler_type == "k-dpm-adaptive":
228
+ return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
229
+ elif sampler_type == "dpmpp-2m-sde":
230
+ return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
231
+ elif sampler_type == "dpmpp-3m-sde":
232
+ return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
233
+
234
+ # Uses discrete Euler sampling for rectified flow models
235
+ # init_data is init_audio as latents (if this is latent diffusion)
236
+ # For sampling, set both init_data and mask to None
237
+ # For variations, set init_data
238
+ # For inpainting, set both init_data & mask
239
+ def sample_rf(
240
+ model_fn,
241
+ noise,
242
+ init_data=None,
243
+ steps=100,
244
+ sigma_max=1,
245
+ device="cuda",
246
+ callback=None,
247
+ cond_fn=None,
248
+ **extra_args
249
+ ):
250
+
251
+ if sigma_max > 1:
252
+ sigma_max = 1
253
+
254
+ if cond_fn is not None:
255
+ denoiser = make_cond_model_fn(denoiser, cond_fn)
256
+
257
+ wrapped_callback = callback
258
+
259
+ if init_data is not None:
260
+ # VARIATION (no inpainting)
261
+ # Interpolate the init data and the noise for init audio
262
+ x = init_data * (1 - sigma_max) + noise * sigma_max
263
+ else:
264
+ # SAMPLING
265
+ # set the initial latent to noise
266
+ x = noise
267
+
268
+ with torch.cuda.amp.autocast():
269
+ # TODO: Add callback support
270
+ #return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args)
271
+ return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args)
SongBloom/models/musicldm/musicldm_dit.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+ from dataclasses import dataclass
3
+ from functools import partial
4
+ import logging
5
+ import math
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+
11
+
12
+ class FourierFeatures(nn.Module):
13
+ def __init__(self, in_features, out_features, std=1.):
14
+ super().__init__()
15
+ assert out_features % 2 == 0
16
+ self.weight = nn.Parameter(torch.randn(
17
+ [out_features // 2, in_features]) * std)
18
+
19
+ def forward(self, input):
20
+ f = 2 * math.pi * input @ self.weight.T
21
+ return torch.cat([f.cos(), f.sin()], dim=-1)
22
+
23
+
24
+
SongBloom/models/songbloom/songbloom_mvsa.py ADDED
@@ -0,0 +1,572 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+ from dataclasses import dataclass
3
+ from functools import partial
4
+ import logging
5
+ import math
6
+ import typing as tp
7
+
8
+ import torch
9
+ from torch import nn
10
+ import torch.nn.functional as F
11
+ from einops.layers.torch import Rearrange
12
+ from einops import rearrange
13
+ import tqdm
14
+
15
+ from ..base.utils import create_norm_fn
16
+ from ..base.sample import sample_top_k, sample_top_p, multinomial
17
+ from ..musicgen.modules.streaming import StreamingModule
18
+ from ..musicgen.conditioners import (
19
+ get_condition_fuser,
20
+ get_conditioner_provider,
21
+ ConditionType,
22
+ ConditioningProvider,
23
+ ConditionFuser,
24
+ AttributeDropout,
25
+ ClassifierFreeGuidanceDropout,
26
+ ConditioningAttributes,
27
+ WavCondition,
28
+ JointEmbedCondition
29
+ )
30
+
31
+ from ..musicgen.get_backend import get_backend
32
+
33
+ from ..transformer import ContinuousTransformer as DiT_block
34
+ from ..musicldm.musicldm_dit import FourierFeatures
35
+ from ..musicldm.inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler, sample_discrete_euler_with_temperature
36
+
37
+ ConditionTensors = tp.Dict[str, ConditionType]
38
+ CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
39
+
40
+ @dataclass
41
+ class DiTAROutput:
42
+ ar_logit: torch.Tensor
43
+ ar_target: torch.Tensor
44
+ nar_pred: torch.Tensor
45
+ nar_target: torch.Tensor
46
+ nar_t: torch.Tensor
47
+
48
+
49
+
50
+ class MVSA_DiTAR(StreamingModule):
51
+ """
52
+ Multiple skeleton embedding, single compressed vae latent
53
+ eg. V1 V2 V3 A1-3 V4 V5 V6 A4-6
54
+ V -> cross entropy (skeleton)
55
+ A -> local-DiT uncompress -> (A1-3 -> E1 E2 E3)
56
+
57
+ Args:
58
+ StreamingModule (_type_): _description_
59
+ """
60
+
61
+ def __init__(self, condition_provider_cfg, fuser_cfg,
62
+ block_size: int = 32, dim: int = 1024, num_heads: int = 8,
63
+ num_pitch: int = 128, hidden_scale: int = 4, lm_layers: int = 16,
64
+ norm: str = 'layer_norm', pre_norm: bool = False,
65
+ backend='llama',init_std: float=0.02,
66
+ # ======================
67
+ latent_dim: int = 64, diff_layers: int = 8,
68
+ time_cond_type: tp.Literal['adaLM', "prepend"] = "prepend",
69
+ timestep_features_dim: int = 256,
70
+ diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
71
+ timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform",
72
+ rotary_base_val=10000, h_dropout: float = None,
73
+ # ======================
74
+ cfg_dropout: float = 0, cfg_coef: float = 1.0,
75
+ attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}
76
+ ):
77
+ super().__init__()
78
+
79
+ self.condition_provider = get_conditioner_provider(condition_provider_cfg)
80
+ self.fuser = get_condition_fuser(fuser_cfg)
81
+
82
+ self.dim = dim
83
+ self.latent_dim = latent_dim
84
+ self.block_size = block_size
85
+
86
+ self.cfg_coef = cfg_coef
87
+ self.h_dropout = h_dropout if h_dropout is not None else 0.
88
+ self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
89
+ self.att_dropout = AttributeDropout(p=attribute_dropout)
90
+
91
+
92
+ # Build AR lm
93
+ self.num_pitch = num_pitch + 1 # self.num_pitch = <EOS>, self.num_pitch+1 = special
94
+ self.skeleton_emb = nn.Embedding(self.num_pitch + 1, dim)
95
+ self.bos_token = nn.Parameter(torch.empty(dim).normal_(mean=0.0, std=init_std), requires_grad=True)
96
+
97
+
98
+ # self.lm_type = lm_type
99
+ self.backend = backend
100
+
101
+ if self.backend == 'llama':
102
+ self.ar_transformer = get_backend('llama',
103
+ dim, num_heads, lm_layers, hidden_scale,init_std=init_std, rope_theta=rotary_base_val)
104
+ self.ar_transformer.gradient_checkpointing_enable()
105
+ elif self.backend == 'bart':
106
+ self.cross_encoder = get_backend('bart_enc',
107
+ dim, num_heads, lm_layers // 4, hidden_scale,init_std=init_std)
108
+ self.ar_transformer = get_backend('bart_dec',
109
+ dim, num_heads, lm_layers, hidden_scale,init_std=init_std)
110
+ else:
111
+ raise NotImplementedError(f"Illegal backend: {self.backend}!")
112
+
113
+
114
+ self.skeleton_classifier = nn.Sequential(nn.Linear(dim, dim, bias=False),
115
+ nn.SiLU(),
116
+ nn.Linear(dim, self.num_pitch),)
117
+
118
+ self.pre_norm: tp.Optional[nn.Module] = None
119
+ if pre_norm:
120
+ self.pre_norm = create_norm_fn(norm, dim)
121
+ self.reset_streaming()
122
+
123
+ # Build NAR DiT
124
+ self.block_conv = nn.Sequential(
125
+ Rearrange("b d (n s) -> b n (s d)", s=self.block_size),
126
+ nn.Linear(self.block_size * latent_dim, dim),
127
+ nn.SiLU(),
128
+ nn.Linear(dim, dim)
129
+ )
130
+ self.project_in = nn.Linear(latent_dim, dim) if latent_dim != dim else nn.Identity()
131
+ self.project_out = nn.Linear(dim, latent_dim) if latent_dim != dim else nn.Identity()
132
+
133
+ self.timestep_features_dim = timestep_features_dim
134
+ self.time_cond_type = time_cond_type
135
+ assert self.time_cond_type in ['adaLN', "prepend"]
136
+ self.timestep_features = FourierFeatures(1, timestep_features_dim)
137
+ self.to_timestep_embed = nn.Sequential(
138
+ nn.Linear(timestep_features_dim, dim, bias=False),
139
+ nn.SiLU(),
140
+ nn.Linear(dim, dim),
141
+ )
142
+
143
+ self.time_cond_type = time_cond_type
144
+ self.nar_dit = DiT_block(
145
+ dim=dim,
146
+ depth=diff_layers,
147
+ dim_heads= dim // num_heads,
148
+ rotary_pos_emb=True,
149
+ cross_attend=False,
150
+ causal=False,
151
+ ff_kwargs={"dim_ff": dim * hidden_scale, "no_bias": True},
152
+ global_cond_dim=self.dim if self.time_cond_type=="adaLN" else None,
153
+ rotary_base_val = rotary_base_val,
154
+ # init_std=init_std
155
+ )
156
+ self.nar_dit.gradient_checkpointing_enable()
157
+
158
+ self.diffusion_objective = diffusion_objective
159
+ self.timestep_sampler = timestep_sampler
160
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
161
+
162
+ self.init_weights(init_std=init_std)
163
+
164
+
165
+
166
+
167
+ @property
168
+ def special_token_id(self) -> int:
169
+ return self.num_pitch
170
+
171
+
172
+ @property
173
+ def eos_token_id(self) -> int:
174
+ return self.num_pitch-1
175
+
176
+ def forward(self, x_sketch, x_latent, x_len, condition_tensors) -> DiTAROutput:
177
+ '''
178
+ only for train: lm_forward + diffusion_forward (random_t)
179
+ x_sketch: (B,T) # T % block_sz == 0 (no <eos> token) padded with <eos>
180
+ x_latent: (B, D_{in}, T)
181
+ '''
182
+ # AR
183
+ assert torch.all(x_len % self.block_size == 0), f"{x_len}"
184
+ block_num = x_len // self.block_size
185
+
186
+ sketch_emb = self.skeleton_emb(x_sketch)
187
+ latent_emb = self.block_conv(x_latent)
188
+
189
+ B, T, D = sketch_emb.shape
190
+
191
+ lm_input = rearrange(torch.cat([rearrange(sketch_emb, "b (n s) d -> b n s d", s=self.block_size),
192
+ latent_emb.unsqueeze(dim=2)], dim=2), "b n s d -> b (n s) d")
193
+ lm_input = torch.cat([self.bos_token.reshape(1,1,-1).expand(B,-1,-1),
194
+ lm_input], dim=1) #add <sos>
195
+
196
+ new_seq_len = x_len + block_num + 1
197
+
198
+ ar_target = F.pad(x_sketch, (0,1), value=self.eos_token_id)
199
+ for b,l in enumerate(x_len):
200
+ ar_target[b, l+1:] = self.special_token_id # 用来mask掉多余的eos
201
+
202
+
203
+
204
+ lm_out = self.lm_forward(lm_input, condition_tensors)
205
+
206
+
207
+ indices = torch.arange(lm_out.shape[1])
208
+ h_ind = indices[(indices+1) % (self.block_size+1) == 0]
209
+ not_h_ind = indices[(indices+1) % (self.block_size+1) != 0]
210
+
211
+ x_sketch_logit = self.skeleton_classifier(lm_out[:, not_h_ind])
212
+
213
+ # NAR (h + prev_block)
214
+ h_pad = lm_out[:, h_ind] # B, N, D
215
+ h = torch.cat([hh[:hl] for hh, hl in zip(h_pad, block_num)], dim=0)
216
+ block_semantic = rearrange(sketch_emb, "b (n s) d -> b n s d", s=self.block_size) # B, N, 32, D
217
+ current_block_semantic = torch.cat([bb[:bl] for bb, bl in zip(block_semantic, block_num)], dim=0)
218
+
219
+ if self.training: # for CFG
220
+ drop_h_idx = torch.rand((h.shape[0], 1), device=h.device) < self.h_dropout
221
+ h = torch.masked_fill(h, drop_h_idx, 0)
222
+ # current_block_semantic = torch.masked_fill(current_block_semantic, drop_h_idx.unsqueeze(-1), 0)
223
+
224
+ drop_s_idx = torch.rand((current_block_semantic.shape[0], 1), device=current_block_semantic.device) < self.h_dropout
225
+ current_block_semantic = torch.masked_fill(current_block_semantic, drop_s_idx.unsqueeze(-1), 0)
226
+
227
+ with torch.no_grad():
228
+ block_latent = rearrange(x_latent, "b d (n s) -> b n s d", s=self.block_size) # B, N, 32, D
229
+ current_block = torch.cat([bb[:bl] for bb, bl in zip(block_latent, block_num)], dim=0)
230
+ prev_block = torch.cat([bb[:bl] for bb, bl in zip(F.pad(block_latent, (0,0,0,0,1,0)), block_num)], dim=0)
231
+
232
+ # b_indices = torch.randperm(block_latent.shape[0])[:B*16]
233
+ # h, current_block, prev_block = h[b_indices], current_block[b_indices], prev_block[b_indices]
234
+
235
+ orig_type = x_latent.dtype
236
+ with torch.cuda.amp.autocast(enabled=False):
237
+ if self.timestep_sampler == "uniform":
238
+ # Draw uniformly distributed continuous timesteps
239
+ t = self.rng.draw(h.shape[0])[:, 0].to(device=h.device, dtype=h.dtype)
240
+ elif self.timestep_sampler == "logit_normal":
241
+ t = torch.sigmoid(torch.randn(h.shape[0], device=h.device, dtype=h.dtype))
242
+ elif self.timestep_sampler == "trunc_logit_normal":
243
+ # Draw from logistic truncated normal distribution
244
+ from ..musicldm.musicldm_pl import truncated_logistic_normal_rescaled
245
+ t = truncated_logistic_normal_rescaled(h.shape[0]).to(h.device)
246
+ # Flip the distribution
247
+ t = 1 - t
248
+
249
+ # Calculate the noise schedule parameters for those timesteps
250
+ if self.diffusion_objective == "v":
251
+ alphas, sigmas = get_alphas_sigmas(t)
252
+ elif self.diffusion_objective == "rectified_flow":
253
+ alphas, sigmas = 1-t, t
254
+ # Combine the ground truth data and the noise
255
+ alphas = alphas[:, None, None]
256
+ sigmas = sigmas[:, None, None]
257
+ noise = torch.randn_like(current_block)
258
+ noised_inputs = current_block * alphas + noise * sigmas
259
+ if self.diffusion_objective == "v": # (a_t - a_{t-1})x_0 + (b_t-b_{t-1}) e = -b x_0 + a e
260
+ targets = noise * alphas - current_block * sigmas
261
+ elif self.diffusion_objective == "rectified_flow": #||(XT-X0) - p(x_t, t)||
262
+ targets = noise - current_block
263
+
264
+ nar_output = self.diffusion_forward(noised_inputs.to(orig_type), t.to(orig_type), h, current_block_semantic, prev_block)
265
+
266
+ return DiTAROutput(
267
+ ar_logit=x_sketch_logit,
268
+ ar_target=ar_target,
269
+ nar_pred=nar_output,
270
+ nar_target=targets.to(orig_type),
271
+ nar_t=t
272
+ )
273
+
274
+
275
+
276
+ def lm_forward(self, sequence, condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
277
+ # import pdb; pdb.set_trace()
278
+ B, T, D = sequence.shape
279
+ if self.pre_norm:
280
+ sequence = self.pre_norm(sequence.to(self.pre_norm.weight.data.dtype))
281
+
282
+ input_, cross_attention_input = self.fuser(sequence, condition_tensors)
283
+
284
+ transformer_input = {
285
+ "inputs_embeds":input_,
286
+ "use_cache": self._is_streaming,
287
+ "past_key_values": self._streaming_state.get('past_key_values', None),
288
+ }
289
+ if self.backend == 'bart': # TODO infer 的时候这个玩意不用重复算
290
+ # TODO attention_mask
291
+ cross_attention_input = self.cross_encoder(inputs_embeds=cross_attention_input)
292
+ transformer_input["encoder_hidden_states"] = cross_attention_input.last_hidden_state
293
+
294
+ output = self.ar_transformer(**transformer_input)
295
+ if self._is_streaming:
296
+ self._streaming_state['past_key_values'] = output.past_key_values
297
+ out = output.last_hidden_state
298
+
299
+
300
+
301
+ if len(self.fuser.fuse2cond['prepend']) > 0:
302
+ out = out[:, -T:, :]
303
+
304
+ return out
305
+
306
+
307
+
308
+
309
+ def diffusion_forward(self,
310
+ x: torch.Tensor,
311
+ t: torch.Tensor, # B,
312
+ h: torch.Tensor,
313
+ s: torch.Tensor, # B, self.block_size, D
314
+ history_x: torch.Tensor,
315
+ cfg_coef: float = None) -> torch.Tensor:
316
+
317
+ if cfg_coef is not None:
318
+ # only for infer
319
+ assert not self.training # only for inference
320
+ x = torch.cat([x,x], dim=0)
321
+ t = torch.cat([t,t], dim=0)
322
+ h = torch.cat([h,torch.zeros_like(h)], dim=0)
323
+ s = torch.cat([s,torch.zeros_like(s)], dim=0)
324
+ history_x = torch.cat([history_x,history_x], dim=0)
325
+
326
+ B, T, _ = x.shape
327
+
328
+ input_ = self.project_in(torch.cat([history_x, x], dim=1))
329
+ # print(h.shape, s.shape, input_.shape)
330
+ input_ = torch.cat([h.unsqueeze(1), s, input_], dim=1)
331
+ # Get the batch of timestep embeddings
332
+ timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]))# (b, embed_dim)
333
+ # breakpoint()
334
+ if self.time_cond_type == "prepend":
335
+ input_ = torch.cat([timestep_embed.unsqueeze(1), input_], dim=1)
336
+
337
+ transformer_input = {
338
+ "x": input_,
339
+ "global_cond": timestep_embed if self.time_cond_type == "adaLN" else None}
340
+
341
+ output = self.nar_dit(**transformer_input)
342
+
343
+ # remove the prefix from the model outputs
344
+ output = output[:, -T:, :]
345
+ output = self.project_out(output)
346
+
347
+ if cfg_coef is not None:
348
+ cond_output, uncond_output = torch.chunk(output, 2, dim=0)
349
+ output = uncond_output + (cond_output - uncond_output) * cfg_coef
350
+
351
+ return output # [B, T, D]
352
+
353
+
354
+
355
+ def _sample_next_block(self,
356
+ sequence: torch.Tensor,
357
+ prev_latents: torch.Tensor,
358
+ condition_tensors: tp.Optional[ConditionTensors] = None,
359
+ cfg_coef: tp.Optional[tp.Union[float, tp.List[float]]] = None,
360
+ steps: int = 50,
361
+ dit_cfg_type: str = 'h',
362
+ use_sampling: bool = False,
363
+ temp: float = 1.0,
364
+ diff_temp: float = 1.0,
365
+ top_k: int = 0,
366
+ top_p: float = 0.0,
367
+ penalty_token_pool: tp.Optional[list] = None) -> torch.Tensor:
368
+ # infer: lm next_token -> (if % block_sz == 0) infer diff
369
+ # 1. sample sketch (lm) -> 2. sample latent (lm+diff)
370
+ sequence = sequence.clone()
371
+
372
+ if isinstance(cfg_coef, tp.Iterable):
373
+ assert len(cfg_coef) == 2
374
+ cfg_coef_lm, cfg_coef_diff = cfg_coef
375
+ else:
376
+ cfg_coef_lm, cfg_coef_diff = cfg_coef, cfg_coef
377
+
378
+ B = sequence.shape[0]
379
+ # import pdb; pdb.set_trace()
380
+
381
+ if condition_tensors:
382
+ # Preparing for CFG, predicting both conditional and unconditional logits.
383
+ sequence = torch.cat([sequence, sequence], dim=0)
384
+
385
+
386
+ # ############### decode sketch #########################
387
+ next_tokens = []
388
+ next_token_embs = []
389
+
390
+ for k in range(self.block_size):
391
+ if self._is_streaming and k > 0:
392
+ lm_inp = sequence[:,-1:]
393
+ else:
394
+ lm_inp = sequence
395
+
396
+ lm_out = self.lm_forward(
397
+ lm_inp,
398
+ condition_tensors=condition_tensors)
399
+ next_pitch_logit = self.skeleton_classifier(lm_out[:, -1:]) # B, 1, card
400
+
401
+ if condition_tensors:
402
+ cond_logit, uncond_logit = next_pitch_logit.split(B, dim=0)
403
+ next_pitch_logit = uncond_logit + (cond_logit - uncond_logit) * cfg_coef_lm
404
+
405
+ # add penalty to pre-sampled tokens
406
+ if penalty_token_pool is not None and len(penalty_token_pool) > 0: # B, T
407
+ for b in range(B):
408
+ # q_count = torch.bincount(penalty_token_pool)
409
+ q_count = torch.bincount(torch.unique(penalty_token_pool[b]))
410
+ tmp = min(q_count.shape[-1], self.num_pitch - 1)
411
+ next_pitch_logit[b, -1, :tmp] /= (1.1 ** q_count[:tmp])
412
+
413
+ # sample k
414
+ if use_sampling and temp > 0.0:
415
+ probs = torch.softmax(next_pitch_logit / temp, dim=-1)
416
+ if top_p > 0.0:
417
+ next_token = sample_top_p(probs, p=top_p)
418
+ elif top_k > 0:
419
+ next_token = sample_top_k(probs, k=top_k)
420
+ else:
421
+ next_token = multinomial(probs, num_samples=1)
422
+ next_token = next_token.squeeze(-1)
423
+ else:
424
+ next_token = torch.argmax(next_pitch_logit, dim=-1) # B, 1
425
+ if penalty_token_pool is not None and len(penalty_token_pool) > 0: # B, T
426
+ penalty_token_pool = torch.cat([penalty_token_pool, next_token], dim=-1)[:,1:]
427
+ next_token_emb = self.skeleton_emb(next_token) #B, 1, d
428
+ next_tokens.append(next_token)
429
+ next_token_embs.append(next_token_emb)
430
+
431
+ if condition_tensors:
432
+ doubled_next_emb = torch.cat([next_token_emb, next_token_emb], dim=0)
433
+ sequence = torch.cat([sequence, doubled_next_emb], dim=1)
434
+ else:
435
+ sequence = torch.cat([sequence, next_token_emb], dim=1)
436
+
437
+ next_tokens = torch.cat(next_tokens, dim=1)
438
+ next_token_embs = torch.cat(next_token_embs, dim=1)
439
+
440
+ # ############### decode latent ###########################
441
+ # 这里求h虽然double了 但是没用classifier-free guidance
442
+ if self._is_streaming:
443
+ lm_inp = sequence[:,-1:]
444
+ else:
445
+ lm_inp = sequence
446
+
447
+ lm_out = self.lm_forward(
448
+ lm_inp,
449
+ condition_tensors=condition_tensors)
450
+
451
+ h = lm_out[:,-1]
452
+
453
+ noise = torch.randn((B, self.block_size, self.latent_dim), device=h.device, dtype=h.dtype)
454
+
455
+ assert dit_cfg_type in ['h', 'global', 'none']
456
+ """
457
+ global: same cfg setting as next-token-prediction
458
+ none: no cfg
459
+ h: no cfg during ar-stage and apply cfg via ar output
460
+ """
461
+ if condition_tensors:
462
+ if dit_cfg_type == 'global':
463
+ noise = torch.cat([noise, noise], dim=0)
464
+ prev_latents = torch.cat([prev_latents, prev_latents], dim=0)
465
+ semantic_embs = torch.cat([next_token_embs, next_token_embs], dim=0)
466
+ else:
467
+ h, _ = h.chunk(2, dim=0)
468
+ semantic_embs = next_token_embs
469
+
470
+
471
+ if self.diffusion_objective == "v":
472
+ next_latent = sample(self.diffusion_forward, noise, steps=steps, eta=0, h=h, s=semantic_embs, history_x=prev_latents,
473
+ cfg_coef=(cfg_coef_diff if dit_cfg_type=='h' else None))
474
+ elif self.diffusion_objective == "rectified_flow":
475
+ # next_latent = sample_discrete_euler(self.diffusion_forward, noise, steps=steps, h=h, s=semantic_embs, history_x=prev_latents,
476
+ # cfg_coef=(cfg_coef_diff if dit_cfg_type=='h' else None))
477
+ next_latent = sample_discrete_euler_with_temperature(self.diffusion_forward, noise, steps=steps, temperature=diff_temp, h=h, s=semantic_embs, history_x=prev_latents,
478
+ cfg_coef=(cfg_coef_diff if dit_cfg_type=='h' else None))
479
+ if condition_tensors and dit_cfg_type == 'global':
480
+ cond_next_latent, uncond_next_latent = torch.chunk(next_latent, 2, dim=0)
481
+ next_latent = uncond_next_latent + (cond_next_latent - uncond_next_latent) * cfg_coef_diff
482
+
483
+ latent_emb = self.block_conv(next_latent.transpose(1,2))
484
+
485
+ next_block_seq = torch.cat([next_token_embs, latent_emb], dim=1) # B, self.block_size+1, d
486
+
487
+ return next_tokens, next_latent, next_block_seq
488
+
489
+
490
+
491
+ @torch.no_grad()
492
+ def generate(self,
493
+ prompt: tp.Optional[torch.Tensor] = None,
494
+ conditions: tp.List[ConditioningAttributes] = [],
495
+ cfg_coef: tp.Optional[tp.Union[float, tp.List[float]]] = None,
496
+ steps=50,
497
+ dit_cfg_type: str = 'h',
498
+ max_frames: int = 1500, # 60 * 25
499
+ use_sampling: bool = True,
500
+ temp: float = 1.0,
501
+ diff_temp: float = 1.0,
502
+ top_k: int = 0,
503
+ top_p: float = 0.0,
504
+ penalty_repeat: bool = False,
505
+ penalty_window: int = 50) -> torch.Tensor:
506
+ assert not self.training, "generation shouldn't be used in training mode."
507
+
508
+ B = len(conditions)
509
+ assert B==1, "currently do not support batch decoding"
510
+ null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
511
+ conditions = conditions + null_conditions
512
+ tokenized = self.condition_provider.tokenize(conditions)
513
+ condition_tensors = self.condition_provider(tokenized)
514
+
515
+
516
+ sequence = self.bos_token.reshape(1,1,-1).expand(B, 1, -1)
517
+ if prompt is not None:
518
+ # TODO
519
+ raise NotImplementedError
520
+ # sequence = torch.cat([sequence, prompt])
521
+
522
+
523
+ prev_blocks = torch.zeros((B, self.block_size, self.latent_dim), device=sequence.device, dtype=sequence.dtype)
524
+ latent_seq, token_seq = None, None
525
+
526
+ with self.streaming():
527
+ prog_bar = tqdm.tqdm()
528
+ while True:
529
+ if token_seq is None or not penalty_repeat:
530
+ penalty_token_pool = None
531
+ else:
532
+ penalty_token_pool = token_seq[: ,-penalty_window:]
533
+ if penalty_token_pool.shape[-1] < penalty_window:
534
+ penalty_token_pool = F.pad(penalty_token_pool, (penalty_window - penalty_token_pool.shape[-1], 0), value=self.eos_token_id)
535
+ next_tokens, next_latent, next_block_seq = self._sample_next_block(sequence[:, -1: ], prev_blocks, condition_tensors,
536
+ cfg_coef=cfg_coef, steps=steps, dit_cfg_type=dit_cfg_type,
537
+ use_sampling=use_sampling, temp=temp, diff_temp=diff_temp,
538
+ top_k=top_k, top_p=top_p,
539
+ penalty_token_pool=penalty_token_pool)
540
+
541
+ if (next_tokens == self.eos_token_id).any() or sequence.shape[1] > max_frames / self.block_size * (self.block_size+1):
542
+ break
543
+
544
+ latent_seq = next_latent if latent_seq is None else torch.cat([latent_seq, next_latent], dim=1) # B,T, D
545
+ token_seq = next_tokens if token_seq is None else torch.cat([token_seq, next_tokens], dim=1) # B,T
546
+ sequence = torch.cat([sequence, next_block_seq], dim=1)
547
+ prev_blocks = next_latent
548
+
549
+ prog_bar.update(self.block_size)
550
+
551
+
552
+ if latent_seq is None:
553
+ latent_seq = prev_blocks
554
+ return latent_seq.transpose(1,2), token_seq
555
+
556
+
557
+
558
+
559
+ def init_weights(self, init_std=0.02):
560
+
561
+ def _init_weights(module, init_std=0.02):
562
+ if isinstance(module, nn.Linear):
563
+ module.weight.data.normal_(mean=0.0, std=init_std)
564
+ # torch.nn.init.xavier_uniform_(module.weight)
565
+ if module.bias is not None:
566
+ nn.init.constant_(module.bias, 0)
567
+ elif isinstance(module, nn.Embedding):
568
+ module.weight.data.normal_(mean=0.0, std=init_std)
569
+ if module.padding_idx is not None:
570
+ module.weight.data[module.padding_idx].zero_()
571
+
572
+ self.apply(partial(_init_weights, init_std=init_std))
SongBloom/models/songbloom/songbloom_pl.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from functools import partial
3
+ import typing as tp
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn import functional as F
7
+ import torchaudio
8
+ import numpy as np
9
+ import random
10
+ from omegaconf import OmegaConf
11
+ import copy
12
+ import lightning as pl
13
+
14
+ import os, sys
15
+
16
+ from ..musicgen.conditioners import WavCondition, JointEmbedCondition, ConditioningAttributes
17
+ from ..vae_frontend import StableVAE
18
+ from .songbloom_mvsa import MVSA_DiTAR
19
+ from ...g2p.lyric_common import key2processor, symbols, LABELS
20
+
21
+
22
+ os.environ['TOKENIZERS_PARALLELISM'] = "false"
23
+
24
+
25
+ class SongBloom_PL(pl.LightningModule):
26
+ def __init__(self, cfg):
27
+ super().__init__()
28
+ # 关闭自动优化
29
+ # self.automatic_optimization = False
30
+
31
+ self.cfg = cfg
32
+
33
+ # Build VAE
34
+ self.vae = StableVAE(**cfg.vae).eval()
35
+ assert self.cfg.model['latent_dim'] == self.vae.channel_dim
36
+
37
+
38
+ self.save_hyperparameters(cfg)
39
+ if self.vae is not None:
40
+ for param in self.vae.parameters():
41
+ param.requires_grad = False
42
+
43
+ # Build DiT
44
+ model_cfg = OmegaConf.to_container(copy.deepcopy(cfg.model), resolve=True)
45
+ for cond_name in model_cfg["condition_provider_cfg"]:
46
+ if model_cfg["condition_provider_cfg"][cond_name]['type'] == 'audio_tokenizer_wrapper':
47
+ model_cfg["condition_provider_cfg"][cond_name]["audio_tokenizer"] = self.vae
48
+ model_cfg["condition_provider_cfg"][cond_name]["cache"] = False
49
+
50
+
51
+ self.model = MVSA_DiTAR(**model_cfg)
52
+ # print(self.model)
53
+
54
+
55
+
56
+
57
+
58
+
59
+ ####################################
60
+
61
+ class SongBloom_Sampler:
62
+
63
+ def __init__(self, compression_model: StableVAE, diffusion: MVSA_DiTAR, lyric_processor_key,
64
+ max_duration: float, prompt_duration: tp.Optional[float] = None):
65
+ self.compression_model = compression_model
66
+ self.diffusion = diffusion
67
+ self.lyric_processor_key = lyric_processor_key
68
+ self.lyric_processor = key2processor.get(lyric_processor_key) if lyric_processor_key is not None else lambda x: x
69
+ # import pdb; pdb.set_trace()
70
+
71
+ assert max_duration is not None
72
+ self.max_duration: float = max_duration
73
+ self.prompt_duration = prompt_duration
74
+
75
+
76
+ self.device = next(iter(diffusion.parameters())).device
77
+ self.generation_params: dict = {}
78
+ # self.set_generation_params(duration=15) # 15 seconds by default
79
+ self.set_generation_params(cfg_coef=1.5, steps=50, dit_cfg_type='h',
80
+ use_sampling=True, top_k=200, max_frames=self.max_duration * 25)
81
+ self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
82
+
83
+ @classmethod
84
+ def build_from_trainer(cls, cfg, strict=True, dtype=torch.float32):
85
+ model_light = SongBloom_PL(cfg)
86
+ incompatible = model_light.load_state_dict(torch.load(cfg.pretrained_path, map_location='cpu'), strict=strict)
87
+
88
+ lyric_processor_key = cfg.train_dataset.lyric_processor
89
+
90
+ print(incompatible)
91
+
92
+ model_light = model_light.eval().cuda().to(dtype=dtype)
93
+ model = cls(
94
+ compression_model = model_light.vae,
95
+ diffusion = model_light.model,
96
+ lyric_processor_key = lyric_processor_key,
97
+ max_duration = cfg.max_dur,
98
+ prompt_duration = cfg.sr * cfg.train_dataset.prompt_len
99
+
100
+ )
101
+ model.set_generation_params(**cfg.inference)
102
+ return model
103
+
104
+ @property
105
+ def frame_rate(self) -> float:
106
+ """Roughly the number of AR steps per seconds."""
107
+ return self.compression_model.frame_rate
108
+
109
+ @property
110
+ def sample_rate(self) -> int:
111
+ """Sample rate of the generated audio."""
112
+ return self.compression_model.sample_rate
113
+
114
+
115
+ def set_generation_params(self, **kwargs):
116
+ """Set the generation parameters."""
117
+ self.generation_params.update(kwargs)
118
+
119
+ # Mulan Inference
120
+ @torch.no_grad()
121
+ def generate(self, lyrics, prompt_wav) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
122
+ """ Generate samples conditioned on text and melody.
123
+ """
124
+ # breakpoint()
125
+ assert prompt_wav.ndim == 2
126
+ if self.prompt_duration is not None:
127
+ prompt_wav = prompt_wav[..., :self.prompt_duration]
128
+
129
+ attributes, _ = self._prepare_tokens_and_attributes(conditions={"lyrics": [self._process_lyric(lyrics)], "prompt_wav": [prompt_wav]},
130
+ prompt=None, prompt_tokens=None)
131
+
132
+ # breakpoint()
133
+ print(self.generation_params)
134
+ latent_seq, token_seq = self.diffusion.generate(None, attributes, **self.generation_params)
135
+ # print(token_seq)
136
+ audio_recon = self.compression_model.decode(latent_seq).float()
137
+
138
+ return audio_recon
139
+
140
+
141
+ def _process_lyric(self, input_lyric):
142
+ if self.lyric_processor_key == 'pinyin':
143
+ processed_lyric = self.lyric_processor(input_lyric)
144
+ else:
145
+ processed_lyric = []
146
+ check_lyric = input_lyric.split(" ")
147
+ for ii in range(len(check_lyric)):
148
+ if check_lyric[ii] not in symbols and check_lyric[ii] not in LABELS.keys() and len(check_lyric[ii]) > 0:
149
+ new = self.lyric_processor(check_lyric[ii])
150
+ check_lyric[ii] = new
151
+ processed_lyric = " ".join(check_lyric)
152
+
153
+ return processed_lyric
154
+
155
+ @torch.no_grad()
156
+ def _prepare_tokens_and_attributes(
157
+ self,
158
+ conditions: tp.Dict[str, tp.List[tp.Union[str, torch.Tensor]]],
159
+ prompt: tp.Optional[torch.Tensor],
160
+ prompt_tokens: tp.Optional[torch.Tensor] = None,
161
+ ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
162
+ """Prepare model inputs.
163
+
164
+ Args:
165
+ descriptions (list of str): A list of strings used as text conditioning.
166
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
167
+ melody_wavs (torch.Tensor, optional): A batch of waveforms
168
+ used as melody conditioning. Defaults to None.
169
+ """
170
+ batch_size = len(list(conditions.values())[0])
171
+ assert batch_size == 1
172
+ # breakpoint()
173
+ attributes = [ConditioningAttributes() for _ in range(batch_size)]
174
+ for k in self.diffusion.condition_provider.conditioners:
175
+ conds = conditions.pop(k, [None for _ in attributes])
176
+ for attr, cond in zip(attributes, conds):
177
+ if self.diffusion.condition_provider.conditioner_type[k] == 'wav':
178
+ if cond is None:
179
+ attr.wav[k] = WavCondition(
180
+ torch.zeros((1, 1, 1), device=self.device),
181
+ torch.tensor([0], device=self.device).long(),
182
+ sample_rate=[self.sample_rate],
183
+ path=[None])
184
+ else:
185
+ attr.wav[k] = WavCondition(
186
+ cond.to(device=self.device).unsqueeze(0), # 1,C,T .mean(dim=0, keepdim=True)
187
+ torch.tensor([cond.shape[-1]], device=self.device).long(),
188
+ sample_rate=[self.sample_rate],
189
+ path=[None])
190
+ elif self.diffusion.condition_provider.conditioner_type[k] == 'text':
191
+ attr.text[k] = cond
192
+ elif self.diffusion.condition_provider.conditioner_type[k] == 'joint_embed':
193
+ if cond is None or isinstance(cond, str):
194
+ attr.joint_embed[k] = JointEmbedCondition(
195
+ torch.zeros((1, 1, 1), device=self.device),
196
+ [cond],
197
+ torch.tensor([0], device=self.device).long(),
198
+ sample_rate=[self.sample_rate],
199
+ path=[None])
200
+ elif isinstance(cond, torch.Tensor):
201
+ attr.joint_embed[k] = JointEmbedCondition(
202
+ cond.to(device=self.device).mean(dim=0, keepdim=True).unsqueeze(0),
203
+ [None],
204
+ torch.tensor([cond.shape[-1]], device=self.device).long(),
205
+ sample_rate=[self.sample_rate],
206
+ path=[None])
207
+ else:
208
+ raise NotImplementedError
209
+ assert conditions == {}, f"Find illegal conditions: {conditions}, support keys: {self.lm.condition_provider.conditioners}"
210
+ # breakpoint()
211
+ print(attributes)
212
+
213
+ if prompt_tokens is not None:
214
+ prompt_tokens = prompt_tokens.to(self.device)
215
+ assert prompt is None
216
+ elif prompt is not None:
217
+ assert len(attributes) == len(prompt), "Prompt and nb. attributes doesn't match"
218
+ prompt = prompt.to(self.device)
219
+ prompt_tokens = self.compression_model.encode(prompt)
220
+ else:
221
+ prompt_tokens = None
222
+
223
+ return attributes, prompt_tokens
224
+
SongBloom/models/transformer.py ADDED
@@ -0,0 +1,937 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import reduce, partial
2
+ from packaging import version
3
+
4
+ from einops import rearrange, repeat
5
+ from einops.layers.torch import Rearrange
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn, einsum
9
+ from torch.cuda.amp import autocast
10
+ from typing import Callable, Literal
11
+ import os, sys
12
+ import warnings
13
+ from torch.utils import checkpoint
14
+ from transformers.utils import is_flash_attn_2_available
15
+
16
+ try:
17
+ assert is_flash_attn_2_available()
18
+ assert torch.cuda.get_device_capability(torch.device("cuda")) >= (8, 0)
19
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
20
+ from flash_attn.bert_padding import index_first_axis, unpad_input, pad_input
21
+ assert os.environ.get("DISABLE_FLASH_ATTN",'0') != "1"
22
+ except Exception as e:
23
+ flash_attn_kvpacked_func = None
24
+ flash_attn_func = None
25
+ warnings.warn("Not support flash-attn!")
26
+
27
+ try:
28
+ import natten
29
+ except ImportError:
30
+ natten = None
31
+
32
+ def checkpoint(function, *args, **kwargs):
33
+ kwargs.setdefault("use_reentrant", False)
34
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
35
+
36
+
37
+ # Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
38
+ # License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
39
+
40
+ def create_causal_mask(i, j, device):
41
+ return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
42
+
43
+ def or_reduce(masks):
44
+ head, *body = masks
45
+ for rest in body:
46
+ head = head | rest
47
+ return head
48
+
49
+ # positional embeddings
50
+
51
+ class AbsolutePositionalEmbedding(nn.Module):
52
+ def __init__(self, dim, max_seq_len):
53
+ super().__init__()
54
+ self.scale = dim ** -0.5
55
+ self.max_seq_len = max_seq_len
56
+ self.emb = nn.Embedding(max_seq_len, dim)
57
+
58
+ def forward(self, x, pos = None, seq_start_pos = None):
59
+ seq_len, device = x.shape[1], x.device
60
+ assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
61
+
62
+ if pos is None:
63
+ pos = torch.arange(seq_len, device = device)
64
+
65
+ if seq_start_pos is not None:
66
+ pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
67
+
68
+ pos_emb = self.emb(pos)
69
+ pos_emb = pos_emb * self.scale
70
+ return pos_emb
71
+
72
+ class ScaledSinusoidalEmbedding(nn.Module):
73
+ def __init__(self, dim, theta = 10000):
74
+ super().__init__()
75
+ assert (dim % 2) == 0, 'dimension must be divisible by 2'
76
+ self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
77
+
78
+ half_dim = dim // 2
79
+ freq_seq = torch.arange(half_dim).float() / half_dim
80
+ inv_freq = theta ** -freq_seq
81
+ self.register_buffer('inv_freq', inv_freq, persistent = False)
82
+
83
+ def forward(self, x, pos = None, seq_start_pos = None):
84
+ seq_len, device = x.shape[1], x.device
85
+
86
+ if pos is None:
87
+ pos = torch.arange(seq_len, device = device)
88
+
89
+ if seq_start_pos is not None:
90
+ pos = pos - seq_start_pos[..., None]
91
+
92
+ emb = einsum('i, j -> i j', pos, self.inv_freq)
93
+ emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
94
+ return emb * self.scale
95
+
96
+ class RotaryEmbedding(nn.Module):
97
+ def __init__(
98
+ self,
99
+ dim,
100
+ use_xpos = False,
101
+ scale_base = 512,
102
+ interpolation_factor = 1.,
103
+ base = 10000,
104
+ base_rescale_factor = 1.
105
+ ):
106
+ super().__init__()
107
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
108
+ # has some connection to NTK literature
109
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
110
+ base *= base_rescale_factor ** (dim / (dim - 2))
111
+
112
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
113
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
114
+
115
+ assert interpolation_factor >= 1.
116
+ self.interpolation_factor = interpolation_factor
117
+
118
+ if not use_xpos:
119
+ self.register_buffer('scale', None)
120
+ else:
121
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
122
+
123
+ self.scale_base = scale_base
124
+ self.register_buffer('scale', scale)
125
+
126
+ def forward_from_seq_len(self, seq_len):
127
+ device = self.inv_freq.device
128
+
129
+ t = torch.arange(seq_len, device = device)
130
+ return self.forward(t)
131
+
132
+ @autocast(enabled = False)
133
+ def forward(self, t):
134
+ device = self.inv_freq.device
135
+
136
+ t = t.to(torch.float32)
137
+ seq_len = t.shape[0]
138
+
139
+ t = t / self.interpolation_factor
140
+
141
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
142
+ freqs = torch.cat((freqs, freqs), dim = -1)
143
+
144
+ if self.scale is None:
145
+ return freqs, 1.
146
+
147
+ power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
148
+ scale = self.scale ** rearrange(power, 'n -> n 1')
149
+ scale = torch.cat((scale, scale), dim = -1)
150
+
151
+ return freqs, scale
152
+
153
+ class RotaryEmbedding2D(RotaryEmbedding):
154
+ def __init__(self, dim, w, **kwargs):
155
+ super().__init__(dim // 2, **kwargs)
156
+ self.w = w
157
+
158
+
159
+ def forward_from_seq_len(self, seq_len):
160
+ device = self.inv_freq.device
161
+ assert seq_len % self.w == 0 , f"{seq_len} % {self.w} != 0"
162
+ h_len = seq_len // self.w
163
+
164
+ t_h = torch.arange(h_len, device = device)
165
+ t_w = torch.arange(self.w, device = device)
166
+
167
+ return self.forward(t_h, t_w)
168
+
169
+ @autocast(enabled = False)
170
+ def forward(self, t_h: torch.Tensor, t_w: torch.Tensor):
171
+ repeat_t_h = t_h.repeat_interleave(t_w.shape[0], dim=0)
172
+ repeat_t_w = t_w.repeat(t_h.shape[0])
173
+ freq_h, scale_h = super().forward(repeat_t_h)
174
+ freq_w, scale_w = super().forward(repeat_t_w)
175
+ freq = torch.stack([freq_h, freq_w], dim=-1) #h*w, D//2, 2
176
+ freq = torch.cat(torch.unbind(freq, dim=-2), dim=-1)
177
+
178
+ if self.scale is None:
179
+ scale = 1.
180
+ else:
181
+ scale = torch.stack([scale_h, scale_w], dim=-1)
182
+ scale = torch.cat(torch.unbind(scale, dim=-2), dim=-1)
183
+
184
+ return freq, scale
185
+
186
+
187
+
188
+
189
+ def rotate_half(x):
190
+ x = rearrange(x, '... (j d) -> ... j d', j = 2)
191
+ x1, x2 = x.unbind(dim = -2)
192
+ return torch.cat((-x2, x1), dim = -1)
193
+
194
+ @autocast(enabled = False)
195
+ def apply_rotary_pos_emb(t, freqs, scale = 1):
196
+ out_dtype = t.dtype
197
+
198
+ # cast to float32 if necessary for numerical stability
199
+ dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
200
+ rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
201
+ freqs, t = freqs.to(dtype), t.to(dtype)
202
+ freqs = freqs[-seq_len:, :]
203
+
204
+ if t.ndim == 4 and freqs.ndim == 3:
205
+ freqs = rearrange(freqs, 'b n d -> b 1 n d')
206
+
207
+ # partial rotary embeddings, Wang et al. GPT-J
208
+ t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
209
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
210
+
211
+ t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
212
+ return torch.cat((t, t_unrotated), dim = -1)
213
+
214
+ # norms
215
+ class LayerNorm(nn.Module):
216
+ def __init__(self, dim, bias=False, fix_scale=False):
217
+ """
218
+ bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
219
+ """
220
+ super().__init__()
221
+
222
+ if fix_scale:
223
+ self.register_buffer("gamma", torch.ones(dim))
224
+ else:
225
+ self.gamma = nn.Parameter(torch.ones(dim))
226
+
227
+ if bias:
228
+ self.beta = nn.Parameter(torch.zeros(dim))
229
+ else:
230
+ self.register_buffer("beta", torch.zeros(dim))
231
+
232
+
233
+ def forward(self, x):
234
+ return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta)
235
+
236
+ # feedforward
237
+
238
+ class GLU(nn.Module):
239
+ def __init__(
240
+ self,
241
+ dim_in,
242
+ dim_out,
243
+ activation: Callable,
244
+ use_conv = False,
245
+ conv_kernel_size = 3,
246
+ bias = False,
247
+ ):
248
+ super().__init__()
249
+ self.act = activation
250
+ self.up_proj = nn.Linear(dim_in, dim_out, bias=bias) if not use_conv else nn.Conv1d(dim_in, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2))
251
+ self.gate_proj = nn.Linear(dim_in, dim_out, bias=bias) if not use_conv else nn.Conv1d(dim_in, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2))
252
+ self.use_conv = use_conv
253
+
254
+ def forward(self, x):
255
+ if self.use_conv:
256
+ x = rearrange(x, 'b n d -> b d n')
257
+ gate = self.gate_proj(x)
258
+ x = self.up_proj(x)
259
+ x = rearrange(x, 'b d n -> b n d')
260
+ gate = rearrange(gate, 'b d n -> b n d')
261
+ else:
262
+ gate = self.gate_proj(x)
263
+ x = self.up_proj(x)
264
+
265
+ return x * self.act(gate)
266
+
267
+ class FeedForward(nn.Module):
268
+ def __init__(
269
+ self,
270
+ dim,
271
+ dim_out = None,
272
+ dim_ff = None,
273
+ no_bias = False,
274
+ glu = True,
275
+ use_conv = False,
276
+ conv_kernel_size = 3,
277
+ zero_init_output = True,
278
+ ):
279
+ super().__init__()
280
+ inner_dim = dim_ff if dim_ff is not None else 4 * dim
281
+
282
+ # Default to SwiGLU
283
+
284
+ activation = nn.SiLU()
285
+
286
+ dim_out = dim if dim_out is None else dim_out
287
+
288
+ if glu:
289
+ linear_in = GLU(dim, inner_dim, activation, bias=not no_bias)
290
+ else:
291
+ linear_in = nn.Sequential(
292
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
293
+ nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias),
294
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
295
+ activation
296
+ )
297
+
298
+ linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias)
299
+
300
+ # init last linear layer to 0
301
+ if zero_init_output:
302
+ nn.init.zeros_(linear_out.weight)
303
+ if not no_bias:
304
+ nn.init.zeros_(linear_out.bias)
305
+
306
+
307
+ self.ff = nn.Sequential(
308
+ linear_in,
309
+ Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
310
+ linear_out,
311
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
312
+ )
313
+
314
+ def forward(self, x):
315
+ return self.ff(x)
316
+
317
+ class Attention(nn.Module):
318
+ def __init__(
319
+ self,
320
+ dim,
321
+ dim_heads = 64,
322
+ dim_context = None,
323
+ causal = False,
324
+ zero_init_output=True,
325
+ qk_norm: Literal['l2', 'ln', 'none'] = 'none',
326
+ natten_kernel_size = None
327
+ ):
328
+ super().__init__()
329
+ self.dim = dim
330
+ self.dim_heads = dim_heads
331
+ self.causal = causal
332
+
333
+ dim_kv = dim_context if dim_context is not None else dim
334
+
335
+ self.num_heads = dim // dim_heads
336
+ self.kv_heads = dim_kv // dim_heads
337
+
338
+ if dim_context is not None:
339
+ self.to_q = nn.Linear(dim, dim, bias=False)
340
+ self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
341
+ else:
342
+ self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
343
+
344
+ self.to_out = nn.Linear(dim, dim, bias=False)
345
+
346
+ if zero_init_output:
347
+ nn.init.zeros_(self.to_out.weight)
348
+
349
+ self.qk_norm = qk_norm
350
+
351
+ if self.qk_norm == "ln":
352
+ self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
353
+ self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
354
+
355
+ # Using 1d neighborhood attention
356
+ self.natten_kernel_size = natten_kernel_size
357
+ if natten_kernel_size is not None:
358
+ return
359
+
360
+ self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
361
+
362
+ self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None
363
+
364
+ self.sdp_kwargs = dict(
365
+ enable_flash = True,
366
+ enable_math = True,
367
+ enable_mem_efficient = True
368
+ )
369
+
370
+ def flash_attn(
371
+ self,
372
+ q,
373
+ k,
374
+ v,
375
+ mask = None,
376
+ causal = None
377
+ ):
378
+ batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device
379
+ kv_heads = k.shape[1]
380
+ # Recommended for multi-query single-key-value attention by Tri Dao
381
+ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
382
+
383
+ if heads != kv_heads:
384
+ # Repeat interleave kv_heads to match q_heads
385
+ heads_per_kv_head = heads // kv_heads
386
+ k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
387
+
388
+ if k.ndim == 3:
389
+ k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
390
+
391
+ if v.ndim == 3:
392
+ v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
393
+
394
+ causal = self.causal if causal is None else causal
395
+
396
+ if q_len == 1 and causal:
397
+ causal = False
398
+
399
+ if mask is not None:
400
+ assert mask.ndim == 4
401
+ mask = mask.expand(batch, heads, q_len, k_len)
402
+
403
+ # handle kv cache - this should be bypassable in updated flash attention 2
404
+
405
+ if k_len > q_len and causal:
406
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
407
+ if mask is None:
408
+ mask = ~causal_mask
409
+ else:
410
+ mask = mask & ~causal_mask
411
+ causal = False
412
+
413
+ # manually handle causal mask, if another mask was given
414
+
415
+ row_is_entirely_masked = None
416
+
417
+ if mask is not None and causal:
418
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
419
+ mask = mask & ~causal_mask
420
+
421
+ # protect against an entire row being masked out
422
+
423
+ row_is_entirely_masked = ~mask.any(dim = -1)
424
+ mask[..., 0] = mask[..., 0] | row_is_entirely_masked
425
+
426
+ causal = False
427
+
428
+ with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
429
+ out = F.scaled_dot_product_attention(
430
+ q, k, v,
431
+ attn_mask = mask,
432
+ is_causal = causal
433
+ )
434
+
435
+ # for a row that is entirely masked out, should zero out the output of that row token
436
+
437
+ if row_is_entirely_masked is not None:
438
+ out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
439
+
440
+ return out
441
+
442
+ def forward(
443
+ self,
444
+ x,
445
+ context = None,
446
+ mask = None,
447
+ context_mask = None,
448
+ rotary_pos_emb = None,
449
+ causal = None
450
+ ):
451
+ h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
452
+
453
+ kv_input = context if has_context else x
454
+
455
+ if hasattr(self, 'to_q'):
456
+ # Use separate linear projections for q and k/v
457
+ q = self.to_q(x)
458
+ q = rearrange(q, 'b n (h d) -> b h n d', h = h)
459
+
460
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
461
+
462
+ k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
463
+ else:
464
+ # Use fused linear projection
465
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
466
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
467
+
468
+ # Normalize q and k for cosine sim attention
469
+ if self.qk_norm == "l2":
470
+ q = F.normalize(q, dim=-1)
471
+ k = F.normalize(k, dim=-1)
472
+ elif self.qk_norm == "ln":
473
+ q = self.q_norm(q)
474
+ k = self.k_norm(k)
475
+
476
+ if rotary_pos_emb is not None and not has_context:
477
+ freqs, _ = rotary_pos_emb
478
+
479
+ q_dtype = q.dtype
480
+ k_dtype = k.dtype
481
+
482
+ q = q.to(torch.float32)
483
+ k = k.to(torch.float32)
484
+ freqs = freqs.to(torch.float32)
485
+
486
+ q = apply_rotary_pos_emb(q, freqs)
487
+ k = apply_rotary_pos_emb(k, freqs)
488
+
489
+ q = q.to(q_dtype)
490
+ k = k.to(k_dtype)
491
+
492
+ # TODO 这里这俩都是 [B, k/Q_len]这样的格式
493
+ # context mask也许应该改成 [B, Q_len, K_len]
494
+ # 并且下面flash_attn 默认假设attn靠左部分全为1
495
+ input_mask = context_mask # cross-attn
496
+ if input_mask is None and not has_context: # self-attn
497
+ input_mask = mask
498
+
499
+ # determine masking
500
+ masks = []
501
+ final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
502
+
503
+ if input_mask is not None:
504
+ input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
505
+ masks.append(~input_mask)
506
+
507
+ # Other masks will be added here later
508
+
509
+ if len(masks) > 0:
510
+ final_attn_mask = ~or_reduce(masks)
511
+
512
+ n, device = q.shape[-2], q.device
513
+
514
+ causal = self.causal if causal is None else causal
515
+ if n == 1 and causal:
516
+ causal = False
517
+ if self.natten_kernel_size is not None:
518
+ if natten is None:
519
+ raise ImportError('natten not installed, please install natten to use neighborhood attention')
520
+
521
+ dtype_in = q.dtype
522
+ q, k, v = map(lambda t: t.to(torch.float32), (q, k, v))
523
+
524
+ attn = natten.functional.natten1dqk(q, k, kernel_size = self.natten_kernel_size, dilation=1)
525
+
526
+ if final_attn_mask is not None:
527
+ attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max)
528
+
529
+ attn = F.softmax(attn, dim=-1, dtype=torch.float32)
530
+
531
+ out = natten.functional.natten1dav(attn, v, kernel_size = self.natten_kernel_size, dilation=1).to(dtype_in)
532
+
533
+ # Prioritize Flash Attention 2
534
+ elif self.use_fa_flash:
535
+ fa_dtype_in = q.dtype
536
+ if q.dtype in [torch.float, torch.float32]:
537
+ target_dtype = self.to_out.weight.dtype if self.to_out.weight.dtype not in [torch.float, torch.float32] else torch.float16
538
+ warnings.warn(
539
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
540
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
541
+ f" {target_dtype}."
542
+ )
543
+ q, k, v = map(lambda t: t.to(target_dtype), (q, k, v))
544
+ q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v))
545
+ # out = flash_attn_func(q, k, v, causal = causal)
546
+ if final_attn_mask is not None:
547
+ # Check if the mask meets the requirement of FlashAttn
548
+ kv_seq_mask = final_attn_mask.squeeze(dim=[1,2])
549
+ kv_reallens = kv_seq_mask.sum(dim=-1, dtype=torch.int32)
550
+ first_zero_indices = torch.argmax((kv_seq_mask == 0).int(), dim=1).masked_fill(kv_seq_mask[:,-1] != 0, kv_seq_mask.shape[1])
551
+ assert (kv_reallens == first_zero_indices).all(), f'{kv_reallens} , {first_zero_indices}'
552
+
553
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = k.shape
554
+ unpad_k, indices_k, cu_seqlens_k, max_seqlen_in_batch_k = unpad_input(k, kv_seq_mask)
555
+ unpad_v = index_first_axis(
556
+ v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
557
+ )
558
+ q_seq_len = q.shape[1]
559
+ unpad_q, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(q, torch.ones((batch_size, q_seq_len), device=q.device, dtype=torch.bool))
560
+ # print(q.shape, k.shape)
561
+ # print(cu_seqlens_q, cu_seqlens_k)
562
+ # breakpoint()
563
+ out_unpad = flash_attn_varlen_func(
564
+ unpad_q,
565
+ unpad_k,
566
+ unpad_v,
567
+ cu_seqlens_q=cu_seqlens_q,
568
+ cu_seqlens_k=cu_seqlens_k,
569
+ max_seqlen_q=max_seqlen_in_batch_q,
570
+ max_seqlen_k=max_seqlen_in_batch_k,
571
+ causal=causal,
572
+ )
573
+ out = pad_input(out_unpad, indices_q, batch_size, q_seq_len)
574
+ else:
575
+ out = flash_attn_func(q, k, v, causal = causal)
576
+
577
+
578
+ out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
579
+ # Fall back to PyTorch implementation
580
+ elif self.use_pt_flash:
581
+ out = self.flash_attn(q, k, v, causal = causal, mask = final_attn_mask)
582
+
583
+ else:
584
+ # Fall back to custom implementation
585
+
586
+ if h != kv_h:
587
+ # Repeat interleave kv_heads to match q_heads
588
+ heads_per_kv_head = h // kv_h
589
+ k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
590
+
591
+ scale = 1. / (q.shape[-1] ** 0.5)
592
+
593
+ kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
594
+
595
+ dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
596
+
597
+ i, j, dtype = *dots.shape[-2:], dots.dtype
598
+
599
+ mask_value = -torch.finfo(dots.dtype).max
600
+
601
+ if final_attn_mask is not None:
602
+ dots = dots.masked_fill(~final_attn_mask, mask_value)
603
+
604
+ if causal:
605
+ causal_mask = self.create_causal_mask(i, j, device = device)
606
+ dots = dots.masked_fill(causal_mask, mask_value)
607
+
608
+ attn = F.softmax(dots, dim=-1, dtype=torch.float32)
609
+ attn = attn.type(dtype)
610
+
611
+ out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
612
+
613
+ # merge heads
614
+ out = rearrange(out, ' b h n d -> b n (h d)')
615
+
616
+ # Communicate between heads
617
+
618
+ # with autocast(enabled = False):
619
+ # out_dtype = out.dtype
620
+ # out = out.to(torch.float32)
621
+ # out = self.to_out(out).to(out_dtype)
622
+ out = self.to_out(out)
623
+
624
+ if mask is not None:
625
+ mask = rearrange(mask, 'b n -> b n 1')
626
+ out = out.masked_fill(~mask, 0.)
627
+
628
+ return out
629
+
630
+ class ConformerModule(nn.Module):
631
+ def __init__(
632
+ self,
633
+ dim,
634
+ norm_kwargs = {},
635
+ ):
636
+
637
+ super().__init__()
638
+
639
+ self.dim = dim
640
+
641
+ self.in_norm = LayerNorm(dim, **norm_kwargs)
642
+ self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
643
+ self.glu = GLU(dim, dim, nn.SiLU())
644
+ self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
645
+ self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
646
+ self.swish = nn.SiLU()
647
+ self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
648
+
649
+ def forward(self, x):
650
+ x = self.in_norm(x)
651
+ x = rearrange(x, 'b n d -> b d n')
652
+ x = self.pointwise_conv(x)
653
+ x = rearrange(x, 'b d n -> b n d')
654
+ x = self.glu(x)
655
+ x = rearrange(x, 'b n d -> b d n')
656
+ x = self.depthwise_conv(x)
657
+ x = rearrange(x, 'b d n -> b n d')
658
+ x = self.mid_norm(x)
659
+ x = self.swish(x)
660
+ x = rearrange(x, 'b n d -> b d n')
661
+ x = self.pointwise_conv_2(x)
662
+ x = rearrange(x, 'b d n -> b n d')
663
+
664
+ return x
665
+
666
+ class TransformerBlock(nn.Module):
667
+ def __init__(
668
+ self,
669
+ dim,
670
+ dim_heads = 64,
671
+ cross_attend = False,
672
+ dim_context = None,
673
+ global_cond_dim = None,
674
+ causal = False,
675
+ zero_init_branch_outputs = True,
676
+ conformer = False,
677
+ layer_ix = -1,
678
+ remove_norms = False,
679
+ attn_kwargs = {},
680
+ ff_kwargs = {},
681
+ norm_kwargs = {}
682
+ ):
683
+
684
+ super().__init__()
685
+ self.dim = dim
686
+ self.dim_heads = dim_heads
687
+ self.cross_attend = cross_attend
688
+ self.dim_context = dim_context
689
+ self.causal = causal
690
+
691
+ self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
692
+
693
+ self.self_attn = Attention(
694
+ dim,
695
+ dim_heads = dim_heads,
696
+ causal = causal,
697
+ zero_init_output=zero_init_branch_outputs,
698
+ **attn_kwargs
699
+ )
700
+
701
+ if cross_attend:
702
+ self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
703
+ self.cross_attn = Attention(
704
+ dim,
705
+ dim_heads = dim_heads,
706
+ dim_context=dim_context,
707
+ causal = causal,
708
+ zero_init_output=zero_init_branch_outputs,
709
+ **attn_kwargs
710
+ )
711
+
712
+ self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
713
+ self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
714
+
715
+ self.layer_ix = layer_ix
716
+
717
+ self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
718
+
719
+ self.global_cond_dim = global_cond_dim
720
+
721
+ if global_cond_dim is not None:
722
+ self.to_scale_shift_gate = nn.Sequential(
723
+ nn.SiLU(),
724
+ nn.Linear(global_cond_dim, dim * 6, bias=False)
725
+ )
726
+
727
+ nn.init.zeros_(self.to_scale_shift_gate[1].weight)
728
+ #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)
729
+
730
+ def forward(
731
+ self,
732
+ x,
733
+ mask = None,
734
+ global_cond=None,
735
+ context = None,
736
+ context_mask = None,
737
+ rotary_pos_emb = None
738
+ ):
739
+ if self.global_cond_dim is not None:
740
+ assert global_cond is not None
741
+ # scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = checkpoint(self.to_scale_shift_gate, global_cond).unsqueeze(1).chunk(6, dim = -1)
742
+ scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)
743
+
744
+ # self-attention with adaLN
745
+ residual = x
746
+ x = self.pre_norm(x)
747
+ x = x * (1 + scale_self) + shift_self
748
+ x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
749
+ x = x * torch.sigmoid(1 - gate_self)
750
+ x = x + residual
751
+
752
+ if context is not None:
753
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
754
+
755
+ if self.conformer is not None:
756
+ x = x + self.conformer(x)
757
+
758
+ # feedforward with adaLN
759
+ residual = x
760
+ x = self.ff_norm(x)
761
+ x = x * (1 + scale_ff) + shift_ff
762
+ x = self.ff(x)
763
+ x = x * torch.sigmoid(1 - gate_ff)
764
+ x = x + residual
765
+
766
+ else:
767
+ assert global_cond is None
768
+ x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
769
+
770
+ if context is not None:
771
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
772
+
773
+ if self.conformer is not None:
774
+ x = x + self.conformer(x)
775
+
776
+ x = x + self.ff(self.ff_norm(x))
777
+
778
+ return x
779
+
780
+ class ContinuousTransformer(nn.Module):
781
+ def __init__(
782
+ self,
783
+ dim,
784
+ depth,
785
+ *,
786
+ dim_in = None,
787
+ dim_out = None,
788
+ dim_heads = 64,
789
+ cross_attend=False,
790
+ cross_atten_layer_idx=None,
791
+ cond_token_dim=None,
792
+ global_cond_dim=None,
793
+ causal=False,
794
+ rotary_pos_emb=True,
795
+ zero_init_branch_outputs=True,
796
+ conformer=False,
797
+ use_sinusoidal_emb=False,
798
+ use_abs_pos_emb=False,
799
+ abs_pos_emb_max_length=10000,
800
+ pos_emb_2d_size=1,
801
+ rotary_base_val=10000,
802
+ init_std=0.02,
803
+ **kwargs
804
+ ):
805
+
806
+ super().__init__()
807
+
808
+ self.dim = dim
809
+ self.depth = depth
810
+ self.causal = causal
811
+ self.layers = nn.ModuleList([])
812
+
813
+ self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
814
+ self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
815
+
816
+ if rotary_pos_emb:
817
+ if pos_emb_2d_size == 1:
818
+ self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), base=rotary_base_val)
819
+ else:
820
+ self.rotary_pos_emb = RotaryEmbedding2D(max(dim_heads // 2, 32), pos_emb_2d_size, base=rotary_base_val)
821
+ else:
822
+ self.rotary_pos_emb = None
823
+
824
+ self.use_sinusoidal_emb = use_sinusoidal_emb
825
+ if use_sinusoidal_emb:
826
+ if pos_emb_2d_size != 1:
827
+ raise NotImplementedError
828
+ self.pos_emb = ScaledSinusoidalEmbedding(dim)
829
+
830
+ self.use_abs_pos_emb = use_abs_pos_emb
831
+ if use_abs_pos_emb:
832
+ if pos_emb_2d_size != 1:
833
+ raise NotImplementedError
834
+ self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
835
+
836
+
837
+ if cross_atten_layer_idx is None:
838
+ cross_atten_layer_idx = list(range(depth))
839
+ for i in range(depth):
840
+ self.layers.append(
841
+ TransformerBlock(
842
+ dim,
843
+ dim_heads = dim_heads,
844
+ cross_attend = cross_attend and (i in cross_atten_layer_idx),
845
+ dim_context = cond_token_dim,
846
+ global_cond_dim = global_cond_dim,
847
+ causal = causal,
848
+ zero_init_branch_outputs = zero_init_branch_outputs,
849
+ conformer=conformer,
850
+ layer_ix=i,
851
+ **kwargs
852
+ )
853
+ )
854
+ self.gradient_checkpointing = False
855
+
856
+ self.apply(partial(self._init_weights,init_std=init_std))
857
+
858
+ def forward(
859
+ self,
860
+ x,
861
+ mask = None,
862
+ prepend_embeds = None,
863
+ prepend_mask = None,
864
+ global_cond = None,
865
+ return_info = False,
866
+ **kwargs
867
+ ):
868
+ batch, seq, device = *x.shape[:2], x.device
869
+
870
+ info = {
871
+ "hidden_states": [],
872
+ }
873
+
874
+ x = self.project_in(x)
875
+
876
+ if prepend_embeds is not None:
877
+ prepend_length, prepend_dim = prepend_embeds.shape[1:]
878
+
879
+ assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
880
+
881
+ x = torch.cat((prepend_embeds, x), dim = -2)
882
+
883
+ if prepend_mask is not None or mask is not None:
884
+ mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
885
+ prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)
886
+
887
+ mask = torch.cat((prepend_mask, mask), dim = -1)
888
+
889
+ # Attention layers
890
+
891
+ if self.rotary_pos_emb is not None:
892
+ rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
893
+ else:
894
+ rotary_pos_emb = None
895
+
896
+ if self.use_sinusoidal_emb or self.use_abs_pos_emb:
897
+ x = x + self.pos_emb(x)
898
+
899
+ # Iterate over the transformer layers
900
+ context, context_mask = kwargs.pop('context', None), kwargs.pop("context_mask", None)
901
+
902
+ for layer_idx, layer in enumerate(self.layers):
903
+ if layer.cross_attend:
904
+ # x = layer(x, mask, global_cond=global_cond, rotary_pos_emb=rotary_pos_emb, context=context, context_mask=context_mask,**kwargs)
905
+ if self.gradient_checkpointing:
906
+ x = checkpoint(layer, x, mask, global_cond, context, context_mask, rotary_pos_emb=rotary_pos_emb, **kwargs)
907
+ else:
908
+ x = layer(x, mask, global_cond, context, context_mask, rotary_pos_emb=rotary_pos_emb, **kwargs)
909
+ else:
910
+ # x = layer(x, mask, global_cond=global_cond, rotary_pos_emb=rotary_pos_emb, **kwargs)
911
+ if self.gradient_checkpointing:
912
+ x = checkpoint(layer, x, mask, global_cond, rotary_pos_emb=rotary_pos_emb, **kwargs)
913
+ else:
914
+ x = layer(x, mask, global_cond, rotary_pos_emb=rotary_pos_emb, **kwargs)
915
+ if return_info:
916
+ info["hidden_states"].append(x)
917
+
918
+ x = self.project_out(x)
919
+
920
+ if return_info:
921
+ return x, info
922
+
923
+ return x
924
+
925
+ def gradient_checkpointing_enable(self):
926
+ self.gradient_checkpointing = True
927
+
928
+
929
+ def _init_weights(self, module, init_std=0.02):
930
+ if isinstance(module, nn.Linear):
931
+ module.weight.data.normal_(mean=0.0, std=init_std)
932
+ if module.bias is not None:
933
+ module.bias.data.zero_()
934
+ elif isinstance(module, nn.Embedding):
935
+ module.weight.data.normal_(mean=0.0, std=init_std)
936
+ if module.padding_idx is not None:
937
+ module.weight.data[module.padding_idx].zero_()
SongBloom/models/vae_frontend/__init__.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import typing as tp
4
+ import torchaudio
5
+ import einops
6
+ from abc import ABC, abstractmethod
7
+
8
+
9
+ class AbstractVAE(ABC, nn.Module):
10
+
11
+ @property
12
+ @abstractmethod
13
+ def frame_rate(self) -> float:
14
+ ...
15
+
16
+ @property
17
+ @abstractmethod
18
+ def orig_sample_rate(self) -> int:
19
+ ...
20
+
21
+
22
+ @property
23
+ @abstractmethod
24
+ def channel_dim(self) -> int:
25
+ ...
26
+
27
+ @property
28
+ @abstractmethod
29
+ def split_bands(self) -> int:
30
+ ...
31
+
32
+ @property
33
+ @abstractmethod
34
+ def input_channel(self) -> int:
35
+ ...
36
+
37
+
38
+ def encode(self, wav) -> torch.Tensor:
39
+ ...
40
+
41
+ def decode(self, latents) -> torch.Tensor:
42
+ ...
43
+
44
+
45
+ from .autoencoders import create_autoencoder_from_config, AudioAutoencoder
46
+ class StableVAE(AbstractVAE):
47
+ def __init__(self, vae_ckpt, vae_cfg, sr=48000) -> None:
48
+ super().__init__()
49
+ import json
50
+ with open(vae_cfg) as f:
51
+ config = json.load(f)
52
+ self.vae: AudioAutoencoder = create_autoencoder_from_config(config)
53
+ self.vae.load_state_dict(torch.load(vae_ckpt)['state_dict'])
54
+ self.sample_rate = sr
55
+ self.rsp48k = torchaudio.transforms.Resample(sr, self.orig_sample_rate) if sr != self.orig_sample_rate else nn.Identity()
56
+
57
+ @torch.no_grad()
58
+ def encode(self, wav: torch.Tensor, sample=True) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
59
+ wav = self.rsp48k(wav)
60
+ if wav.shape[-1] < 2048:
61
+ return torch.zeros((wav.shape[0], self.channel_dim, 0), device=wav.device, dtype=wav.dtype)
62
+ if wav.ndim == 2:
63
+ wav = wav.unsqueeze(1)
64
+ if wav.shape[1] == 1:
65
+ wav = wav.repeat(1, self.vae.in_channels, 1)
66
+ latent = self.vae.encode_audio(wav) # B, 64, T
67
+ return latent
68
+
69
+
70
+
71
+ def decode(self, latents: torch.Tensor, **kwargs):
72
+ # B, 64, T
73
+ with torch.no_grad():
74
+ audio_recon = self.vae.decode_audio(latents, **kwargs)
75
+
76
+ return audio_recon
77
+
78
+ @property
79
+ def frame_rate(self) -> float:
80
+ return float(self.vae.sample_rate) / self.vae.downsampling_ratio
81
+
82
+ @property
83
+ def orig_sample_rate(self) -> int:
84
+ return self.vae.sample_rate
85
+
86
+ @property
87
+ def channel_dim(self) -> int:
88
+ return self.vae.latent_dim
89
+
90
+ @property
91
+ def split_bands(self) -> int:
92
+ return 1
93
+
94
+ @property
95
+ def input_channel(self) -> int:
96
+ return self.vae.in_channels
SongBloom/models/vae_frontend/autoencoders.py ADDED
@@ -0,0 +1,657 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/Stability-AI/stable-audio-tools/tree/main/stable_audio_tools/models
2
+
3
+ import torch
4
+ import math
5
+ import numpy as np
6
+
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from torchaudio import transforms as T
10
+ from dac.nn.layers import WNConv1d, WNConvTranspose1d
11
+ from typing import Literal, Dict, Any
12
+ import os,sys
13
+
14
+ sys.path.insert(0, os.path.dirname(__file__))
15
+ from bottleneck import create_bottleneck_from_config
16
+
17
+ class Bottleneck(nn.Module):
18
+ def __init__(self, is_discrete: bool = False):
19
+ super().__init__()
20
+
21
+ self.is_discrete = is_discrete
22
+
23
+ def encode(self, x, return_info=False, **kwargs):
24
+ raise NotImplementedError
25
+
26
+ def decode(self, x):
27
+ raise NotImplementedError
28
+
29
+ class DiscreteBottleneck(Bottleneck):
30
+ def __init__(self, num_quantizers, codebook_size, tokens_id):
31
+ super().__init__(is_discrete=True)
32
+
33
+ self.num_quantizers = num_quantizers
34
+ self.codebook_size = codebook_size
35
+ self.tokens_id = tokens_id
36
+
37
+ def decode_tokens(self, codes, **kwargs):
38
+ raise NotImplementedError
39
+
40
+
41
+ def checkpoint(function, *args, **kwargs):
42
+ kwargs.setdefault("use_reentrant", False)
43
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
44
+
45
+
46
+
47
+ # Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
48
+ # License available in LICENSES/LICENSE_NVIDIA.txt
49
+ def snake_beta(x, alpha, beta):
50
+ return x + (1.0 / (beta + 1e-9)) * pow(torch.sin(x * alpha), 2)
51
+
52
+ class SnakeBeta(nn.Module):
53
+
54
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
55
+ super(SnakeBeta, self).__init__()
56
+ self.in_features = in_features
57
+
58
+ # initialize alpha
59
+ self.alpha_logscale = alpha_logscale
60
+ if self.alpha_logscale: # log scale alphas initialized to zeros
61
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
62
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
63
+ else: # linear scale alphas initialized to ones
64
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
65
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
66
+
67
+ self.alpha.requires_grad = alpha_trainable
68
+ self.beta.requires_grad = alpha_trainable
69
+
70
+ self.no_div_by_zero = 1e-9
71
+
72
+ def forward(self, x):
73
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
74
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
75
+ if self.alpha_logscale:
76
+ alpha = torch.exp(alpha)
77
+ beta = torch.exp(beta)
78
+ x = snake_beta(x, alpha, beta)
79
+
80
+ return x
81
+
82
+ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
83
+ if activation == "elu":
84
+ act = nn.ELU()
85
+ elif activation == "snake":
86
+ act = SnakeBeta(channels)
87
+ elif activation == "none":
88
+ act = nn.Identity()
89
+ else:
90
+ raise ValueError(f"Unknown activation {activation}")
91
+
92
+ if antialias:
93
+ from alias_free_torch import Activation1d
94
+ act = Activation1d(act)
95
+
96
+ return act
97
+
98
+ class ResidualUnit(nn.Module):
99
+ def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
100
+ super().__init__()
101
+
102
+ self.dilation = dilation
103
+
104
+ padding = (dilation * (7-1)) // 2
105
+
106
+ self.layers = nn.Sequential(
107
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
108
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
109
+ kernel_size=7, dilation=dilation, padding=padding),
110
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
111
+ WNConv1d(in_channels=out_channels, out_channels=out_channels,
112
+ kernel_size=1)
113
+ )
114
+
115
+ def forward(self, x):
116
+ res = x
117
+
118
+ #x = checkpoint(self.layers, x)
119
+ x = self.layers(x)
120
+
121
+ return x + res
122
+
123
+ class EncoderBlock(nn.Module):
124
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
125
+ super().__init__()
126
+
127
+ self.layers = nn.Sequential(
128
+ ResidualUnit(in_channels=in_channels,
129
+ out_channels=in_channels, dilation=1, use_snake=use_snake),
130
+ ResidualUnit(in_channels=in_channels,
131
+ out_channels=in_channels, dilation=3, use_snake=use_snake),
132
+ ResidualUnit(in_channels=in_channels,
133
+ out_channels=in_channels, dilation=9, use_snake=use_snake),
134
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
135
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
136
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
137
+ )
138
+
139
+ def forward(self, x):
140
+ return self.layers(x)
141
+
142
+ class DecoderBlock(nn.Module):
143
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
144
+ super().__init__()
145
+
146
+ if use_nearest_upsample:
147
+ upsample_layer = nn.Sequential(
148
+ nn.Upsample(scale_factor=stride, mode="nearest"),
149
+ WNConv1d(in_channels=in_channels,
150
+ out_channels=out_channels,
151
+ kernel_size=2*stride,
152
+ stride=1,
153
+ bias=False,
154
+ padding='same')
155
+ )
156
+ else:
157
+ upsample_layer = WNConvTranspose1d(in_channels=in_channels,
158
+ out_channels=out_channels,
159
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
160
+
161
+ self.layers = nn.Sequential(
162
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
163
+ upsample_layer,
164
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
165
+ dilation=1, use_snake=use_snake),
166
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
167
+ dilation=3, use_snake=use_snake),
168
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
169
+ dilation=9, use_snake=use_snake),
170
+ )
171
+
172
+ def forward(self, x):
173
+ return self.layers(x)
174
+
175
+ class OobleckEncoder(nn.Module):
176
+ def __init__(self,
177
+ in_channels=2,
178
+ channels=128,
179
+ latent_dim=32,
180
+ c_mults = [1, 2, 4, 8],
181
+ strides = [2, 4, 8, 8],
182
+ use_snake=False,
183
+ antialias_activation=False
184
+ ):
185
+ super().__init__()
186
+
187
+ c_mults = [1] + c_mults
188
+
189
+ self.depth = len(c_mults)
190
+
191
+ layers = [
192
+ WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
193
+ ]
194
+
195
+ for i in range(self.depth-1):
196
+ layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
197
+
198
+ layers += [
199
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
200
+ WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
201
+ ]
202
+
203
+ self.layers = nn.Sequential(*layers)
204
+
205
+ def forward(self, x):
206
+ return self.layers(x)
207
+
208
+
209
+ class OobleckDecoder(nn.Module):
210
+ def __init__(self,
211
+ out_channels=2,
212
+ channels=128,
213
+ latent_dim=32,
214
+ c_mults = [1, 2, 4, 8],
215
+ strides = [2, 4, 8, 8],
216
+ use_snake=False,
217
+ antialias_activation=False,
218
+ use_nearest_upsample=False,
219
+ final_tanh=True):
220
+ super().__init__()
221
+
222
+ c_mults = [1] + c_mults
223
+
224
+ self.depth = len(c_mults)
225
+
226
+ layers = [
227
+ WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
228
+ ]
229
+
230
+ for i in range(self.depth-1, 0, -1):
231
+ layers += [DecoderBlock(
232
+ in_channels=c_mults[i]*channels,
233
+ out_channels=c_mults[i-1]*channels,
234
+ stride=strides[i-1],
235
+ use_snake=use_snake,
236
+ antialias_activation=antialias_activation,
237
+ use_nearest_upsample=use_nearest_upsample
238
+ )
239
+ ]
240
+
241
+ layers += [
242
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
243
+ WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
244
+ nn.Tanh() if final_tanh else nn.Identity()
245
+ ]
246
+
247
+ self.layers = nn.Sequential(*layers)
248
+
249
+ def forward(self, x):
250
+ return self.layers(x)
251
+
252
+
253
+
254
+ class AudioAutoencoder(nn.Module):
255
+ def __init__(
256
+ self,
257
+ encoder,
258
+ decoder,
259
+ latent_dim,
260
+ downsampling_ratio,
261
+ sample_rate,
262
+ io_channels=2,
263
+ bottleneck = None,
264
+ pretransform = None,
265
+ in_channels = None,
266
+ out_channels = None,
267
+ soft_clip = False
268
+ ):
269
+ super().__init__()
270
+
271
+ self.downsampling_ratio = downsampling_ratio
272
+ self.sample_rate = sample_rate
273
+
274
+ self.latent_dim = latent_dim
275
+ self.io_channels = io_channels
276
+ self.in_channels = io_channels
277
+ self.out_channels = io_channels
278
+
279
+ self.min_length = self.downsampling_ratio
280
+
281
+ if in_channels is not None:
282
+ self.in_channels = in_channels
283
+
284
+ if out_channels is not None:
285
+ self.out_channels = out_channels
286
+
287
+ self.bottleneck = bottleneck
288
+
289
+ self.encoder = encoder
290
+
291
+ self.decoder = decoder
292
+
293
+ self.pretransform = pretransform
294
+
295
+ self.soft_clip = soft_clip
296
+
297
+ self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
298
+
299
+ def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
300
+
301
+ info = {}
302
+
303
+ if self.pretransform is not None and not skip_pretransform:
304
+ if self.pretransform.enable_grad:
305
+ if iterate_batch:
306
+ audios = []
307
+ for i in range(audio.shape[0]):
308
+ audios.append(self.pretransform.encode(audio[i:i+1]))
309
+ audio = torch.cat(audios, dim=0)
310
+ else:
311
+ audio = self.pretransform.encode(audio)
312
+ else:
313
+ with torch.no_grad():
314
+ if iterate_batch:
315
+ audios = []
316
+ for i in range(audio.shape[0]):
317
+ audios.append(self.pretransform.encode(audio[i:i+1]))
318
+ audio = torch.cat(audios, dim=0)
319
+ else:
320
+ audio = self.pretransform.encode(audio)
321
+
322
+ if self.encoder is not None:
323
+ if iterate_batch:
324
+ latents = []
325
+ for i in range(audio.shape[0]):
326
+ latents.append(self.encoder(audio[i:i+1]))
327
+ latents = torch.cat(latents, dim=0)
328
+ else:
329
+ latents = self.encoder(audio)
330
+ else:
331
+ latents = audio
332
+
333
+ if self.bottleneck is not None:
334
+ # TODO: Add iterate batch logic, needs to merge the info dicts
335
+ latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
336
+
337
+ info.update(bottleneck_info)
338
+
339
+ if return_info:
340
+ return latents, info
341
+
342
+ return latents
343
+
344
+ def decode(self, latents, iterate_batch=False, **kwargs):
345
+
346
+ if self.bottleneck is not None:
347
+ if iterate_batch:
348
+ decoded = []
349
+ for i in range(latents.shape[0]):
350
+ decoded.append(self.bottleneck.decode(latents[i:i+1]))
351
+ latents = torch.cat(decoded, dim=0)
352
+ else:
353
+ latents = self.bottleneck.decode(latents)
354
+
355
+ if iterate_batch:
356
+ decoded = []
357
+ for i in range(latents.shape[0]):
358
+ decoded.append(self.decoder(latents[i:i+1]))
359
+ decoded = torch.cat(decoded, dim=0)
360
+ else:
361
+ decoded = self.decoder(latents, **kwargs)
362
+
363
+ if self.pretransform is not None:
364
+ if self.pretransform.enable_grad:
365
+ if iterate_batch:
366
+ decodeds = []
367
+ for i in range(decoded.shape[0]):
368
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
369
+ decoded = torch.cat(decodeds, dim=0)
370
+ else:
371
+ decoded = self.pretransform.decode(decoded)
372
+ else:
373
+ with torch.no_grad():
374
+ if iterate_batch:
375
+ decodeds = []
376
+ for i in range(latents.shape[0]):
377
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
378
+ decoded = torch.cat(decodeds, dim=0)
379
+ else:
380
+ decoded = self.pretransform.decode(decoded)
381
+
382
+ if self.soft_clip:
383
+ decoded = torch.tanh(decoded)
384
+
385
+ return decoded
386
+
387
+ def decode_tokens(self, tokens, **kwargs):
388
+ '''
389
+ Decode discrete tokens to audio
390
+ Only works with discrete autoencoders
391
+ '''
392
+
393
+ assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
394
+
395
+ latents = self.bottleneck.decode_tokens(tokens, **kwargs)
396
+
397
+ return self.decode(latents, **kwargs)
398
+
399
+
400
+
401
+ def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
402
+ '''
403
+ Encode audios into latents. Audios should already be preprocesed by preprocess_audio_for_encoder.
404
+ If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
405
+ Overlap and chunk_size params are both measured in number of latents (not audio samples)
406
+ # and therefore you likely could use the same values with decode_audio.
407
+ A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
408
+ Every autoencoder will have a different receptive field size, and thus ideal overlap.
409
+ You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
410
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
411
+ Smaller chunk_size uses less memory, but more compute.
412
+ The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
413
+ For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
414
+ '''
415
+ if not chunked:
416
+ # default behavior. Encode the entire audio in parallel
417
+ return self.encode(audio, **kwargs)
418
+ else:
419
+ # CHUNKED ENCODING
420
+ # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
421
+ samples_per_latent = self.downsampling_ratio
422
+ total_size = audio.shape[2] # in samples
423
+ batch_size = audio.shape[0]
424
+ chunk_size *= samples_per_latent # converting metric in latents to samples
425
+ overlap *= samples_per_latent # converting metric in latents to samples
426
+ hop_size = chunk_size - overlap
427
+ chunks = []
428
+ for i in range(0, total_size - chunk_size + 1, hop_size):
429
+ chunk = audio[:,:,i:i+chunk_size]
430
+ chunks.append(chunk)
431
+ if i+chunk_size != total_size:
432
+ # Final chunk
433
+ chunk = audio[:,:,-chunk_size:]
434
+ chunks.append(chunk)
435
+ chunks = torch.stack(chunks)
436
+ num_chunks = chunks.shape[0]
437
+ # Note: y_size might be a different value from the latent length used in diffusion training
438
+ # because we can encode audio of varying lengths
439
+ # However, the audio should've been padded to a multiple of samples_per_latent by now.
440
+ y_size = total_size // samples_per_latent
441
+ # Create an empty latent, we will populate it with chunks as we encode them
442
+ y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
443
+ for i in range(num_chunks):
444
+ x_chunk = chunks[i,:]
445
+ # encode the chunk
446
+ y_chunk = self.encode(x_chunk)
447
+ # figure out where to put the audio along the time domain
448
+ if i == num_chunks-1:
449
+ # final chunk always goes at the end
450
+ t_end = y_size
451
+ t_start = t_end - y_chunk.shape[2]
452
+ else:
453
+ t_start = i * hop_size // samples_per_latent
454
+ t_end = t_start + chunk_size // samples_per_latent
455
+ # remove the edges of the overlaps
456
+ ol = overlap//samples_per_latent//2
457
+ chunk_start = 0
458
+ chunk_end = y_chunk.shape[2]
459
+ if i > 0:
460
+ # no overlap for the start of the first chunk
461
+ t_start += ol
462
+ chunk_start += ol
463
+ if i < num_chunks-1:
464
+ # no overlap for the end of the last chunk
465
+ t_end -= ol
466
+ chunk_end -= ol
467
+ # paste the chunked audio into our y_final output audio
468
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
469
+ return y_final
470
+
471
+ def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
472
+ '''
473
+ Decode latents to audio.
474
+ If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
475
+ A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
476
+ Every autoencoder will have a different receptive field size, and thus ideal overlap.
477
+ You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
478
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
479
+ Smaller chunk_size uses less memory, but more compute.
480
+ The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
481
+ For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
482
+ '''
483
+ if not chunked:
484
+ # default behavior. Decode the entire latent in parallel
485
+ return self.decode(latents, **kwargs)
486
+ else:
487
+ # chunked decoding
488
+ hop_size = chunk_size - overlap
489
+ total_size = latents.shape[2]
490
+ batch_size = latents.shape[0]
491
+ chunks = []
492
+ for i in range(0, total_size - chunk_size + 1, hop_size):
493
+ chunk = latents[:,:,i:i+chunk_size]
494
+ chunks.append(chunk)
495
+ if i+chunk_size != total_size:
496
+ # Final chunk
497
+ chunk = latents[:,:,-chunk_size:]
498
+ chunks.append(chunk)
499
+ chunks = torch.stack(chunks)
500
+ num_chunks = chunks.shape[0]
501
+ # samples_per_latent is just the downsampling ratio
502
+ samples_per_latent = self.downsampling_ratio
503
+ # Create an empty waveform, we will populate it with chunks as decode them
504
+ y_size = total_size * samples_per_latent
505
+ y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
506
+ for i in range(num_chunks):
507
+ x_chunk = chunks[i,:]
508
+ # decode the chunk
509
+ y_chunk = self.decode(x_chunk)
510
+ # figure out where to put the audio along the time domain
511
+ if i == num_chunks-1:
512
+ # final chunk always goes at the end
513
+ t_end = y_size
514
+ t_start = t_end - y_chunk.shape[2]
515
+ else:
516
+ t_start = i * hop_size * samples_per_latent
517
+ t_end = t_start + chunk_size * samples_per_latent
518
+ # remove the edges of the overlaps
519
+ ol = (overlap//2) * samples_per_latent
520
+ chunk_start = 0
521
+ chunk_end = y_chunk.shape[2]
522
+ if i > 0:
523
+ # no overlap for the start of the first chunk
524
+ t_start += ol
525
+ chunk_start += ol
526
+ if i < num_chunks-1:
527
+ # no overlap for the end of the last chunk
528
+ t_end -= ol
529
+ chunk_end -= ol
530
+ # paste the chunked audio into our y_final output audio
531
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
532
+ return y_final
533
+
534
+
535
+ # AE factories
536
+
537
+ def create_encoder_from_config(encoder_config: Dict[str, Any]):
538
+ encoder_type = encoder_config.get("type", None)
539
+ assert encoder_type is not None, "Encoder type must be specified"
540
+
541
+ if encoder_type == "oobleck":
542
+ encoder = OobleckEncoder(
543
+ **encoder_config["config"]
544
+ )
545
+
546
+ elif encoder_type == "seanet":
547
+ from encodec.modules import SEANetEncoder
548
+ seanet_encoder_config = encoder_config["config"]
549
+
550
+ #SEANet encoder expects strides in reverse order
551
+ seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
552
+ encoder = SEANetEncoder(
553
+ **seanet_encoder_config
554
+ )
555
+ else:
556
+ raise ValueError(f"Unknown encoder type {encoder_type}")
557
+
558
+ requires_grad = encoder_config.get("requires_grad", True)
559
+ if not requires_grad:
560
+ for param in encoder.parameters():
561
+ param.requires_grad = False
562
+
563
+ return encoder
564
+
565
+ def create_decoder_from_config(decoder_config: Dict[str, Any]):
566
+ decoder_type = decoder_config.get("type", None)
567
+ assert decoder_type is not None, "Decoder type must be specified"
568
+
569
+ if decoder_type == "oobleck":
570
+ decoder = OobleckDecoder(
571
+ **decoder_config["config"]
572
+ )
573
+ elif decoder_type == "seanet":
574
+ from encodec.modules import SEANetDecoder
575
+
576
+ decoder = SEANetDecoder(
577
+ **decoder_config["config"]
578
+ )
579
+ else:
580
+ raise ValueError(f"Unknown decoder type {decoder_type}")
581
+
582
+ requires_grad = decoder_config.get("requires_grad", True)
583
+ if not requires_grad:
584
+ for param in decoder.parameters():
585
+ param.requires_grad = False
586
+
587
+ return decoder
588
+
589
+ def create_autoencoder_from_config(config: Dict[str, Any]):
590
+
591
+ # print(config)
592
+ ae_config = config["model"]
593
+
594
+ encoder = create_encoder_from_config(ae_config["encoder"])
595
+ decoder = create_decoder_from_config(ae_config["decoder"])
596
+
597
+ bottleneck = ae_config.get("bottleneck", None)
598
+
599
+ latent_dim = ae_config.get("latent_dim", None)
600
+ assert latent_dim is not None, "latent_dim must be specified in model config"
601
+ downsampling_ratio = ae_config.get("downsampling_ratio", None)
602
+ assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
603
+ io_channels = ae_config.get("io_channels", None)
604
+ assert io_channels is not None, "io_channels must be specified in model config"
605
+ sample_rate = config.get("sample_rate", None)
606
+ assert sample_rate is not None, "sample_rate must be specified in model config"
607
+
608
+ in_channels = ae_config.get("in_channels", None)
609
+ out_channels = ae_config.get("out_channels", None)
610
+
611
+ pretransform = ae_config.get("pretransform", None)
612
+
613
+ if pretransform is not None:
614
+ from stable_audio_tools.models.factory import create_pretransform_from_config
615
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
616
+
617
+ if bottleneck is not None:
618
+ bottleneck = create_bottleneck_from_config(bottleneck)
619
+
620
+ soft_clip = ae_config["decoder"].get("soft_clip", False)
621
+
622
+ return AudioAutoencoder(
623
+ encoder,
624
+ decoder,
625
+ io_channels=io_channels,
626
+ latent_dim=latent_dim,
627
+ downsampling_ratio=downsampling_ratio,
628
+ sample_rate=sample_rate,
629
+ bottleneck=bottleneck,
630
+ pretransform=pretransform,
631
+ in_channels=in_channels,
632
+ out_channels=out_channels,
633
+ soft_clip=soft_clip
634
+ )
635
+
636
+
637
+
638
+ if __name__ == "__main__":
639
+ import json
640
+ import torchaudio
641
+ config_path = 'modelzoo/stable_audio_vae/stable_audio_2_0_vae.json'
642
+ with open(config_path) as f:
643
+ config = json.load(f)
644
+ with torch.no_grad():
645
+ vae_model = create_autoencoder_from_config(config).cuda()
646
+ model_ckpt_path = 'modelzoo/stable_audio_vae/autoencoder.ckpt'
647
+ vae_model.load_state_dict(torch.load(model_ckpt_path)['state_dict'])
648
+
649
+
650
+ input_audios, sr = torchaudio.load("music_example/加勒比海盗 主题.wav")
651
+ input_audios = torchaudio.functional.resample(input_audios, sr, 48000)[...,:2048]
652
+ input_audios = input_audios.unsqueeze(1).repeat(1, 2, 1).cuda()
653
+ latents = vae_model.encode_audio(input_audios)
654
+ recover_audio = vae_model.decode_audio(latents)
655
+ print(recover_audio)
656
+
657
+ breakpoint()