Spaces:
Running
on
L40S
Running
on
L40S
fix some typo
Browse files- LICENSE +211 -211
- README.md +0 -2
- app.py +3 -3
- codeclm/models/codeclm.py +40 -53
- codeclm/tokenizer/Flow1dVAE/generate_septoken.py +3 -2
- codeclm/tokenizer/Flow1dVAE/libs/rvq/core_vq.py +366 -366
- codeclm/tokenizer/Flow1dVAE/model_1rvq.py +710 -710
- codeclm/tokenizer/Flow1dVAE/model_2rvq.py +774 -774
- codeclm/tokenizer/Flow1dVAE/model_4rvq.py +774 -774
- codeclm/tokenizer/Flow1dVAE/model_septoken.py +670 -670
- codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_additionalemb.py +0 -0
- codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_flow.py +0 -0
- codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/pinyin/symbols.py +71 -71
- codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k.py +47 -47
- codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k_vocal.py +47 -47
- codeclm/tokenizer/Flow1dVAE/tools/infer_hifigan48k_speech.py +56 -56
- codeclm/tokenizer/Flow1dVAE/tools/infer_hifigan48k_vocal.py +57 -57
- codeclm/tokenizer/Flow1dVAE/tools/infer_vaehifigan48k.py +59 -59
- codeclm/tokenizer/Flow1dVAE/tools/infer_vaehifigan48k_soundmusic.py +61 -61
- codeclm/tokenizer/Flow1dVAE/tools/infer_vaehifigan48k_speech.py +58 -58
- codeclm/tokenizer/Flow1dVAE/tools/infer_vaehifigan48k_vocal.py +59 -59
- codeclm/tokenizer/Flow1dVAE/tools/mix.py +50 -50
- codeclm/tokenizer/Flow1dVAE/tools/torch_tools.py +142 -142
- codeclm/tokenizer/audio_tokenizer.py +2 -2
- generate_lowmem.py +240 -0
- generate_lowmem.sh +10 -0
- requirements.txt +24 -0
- requirements_nodeps.txt +13 -0
- sample/lyrics.jsonl +1 -1
- tools/gradio/app.py +236 -0
- tools/gradio/levo_inference.py +110 -0
- tools/gradio/levo_inference_lowmem.py +129 -0
- tools/gradio/run.sh +9 -0
- tools/gradio/separator.py +50 -0
LICENSE
CHANGED
@@ -1,211 +1,211 @@
|
|
1 |
-
Tencent is pleased to support the open source community by making SongGeneration available.
|
2 |
-
|
3 |
-
Copyright (C) 2025 Tencent. All rights reserved.
|
4 |
-
|
5 |
-
SongGeneration is licensed under the License Terms of SongGeneration except for the third-party components listed below, which is licensed under different terms. SongGeneration does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
|
6 |
-
|
7 |
-
|
8 |
-
License Terms of SongGeneration:
|
9 |
-
--------------------------------------------------------------------
|
10 |
-
|
11 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
12 |
-
|
13 |
-
- You agree to use the SongGeneration only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances.
|
14 |
-
|
15 |
-
- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
16 |
-
|
17 |
-
For avoidance of doubts, "Software" means the SongGeneration inference-enabling code and the weights made available under this license excluding any pre-trained data and other AI components.
|
18 |
-
|
19 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
20 |
-
|
21 |
-
|
22 |
-
Other dependencies and licenses:
|
23 |
-
|
24 |
-
|
25 |
-
Open Source Software Licensed under the MIT License and Other Licenses of the Third-Party Components therein:
|
26 |
-
--------------------------------------------------------------------
|
27 |
-
1. stable_audio_tools
|
28 |
-
Copyright (c) 2023 Stability AI
|
29 |
-
|
30 |
-
|
31 |
-
Terms of the MIT:
|
32 |
-
--------------------------------------------------------------------
|
33 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
34 |
-
|
35 |
-
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
36 |
-
|
37 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
38 |
-
|
39 |
-
For the license of other third party components, please refer to the following URL:
|
40 |
-
https://github.com/Stability-AI/stable-audio-tools/tree/main/LICENSES
|
41 |
-
|
42 |
-
|
43 |
-
Open Source Software Licensed under the MIT License:
|
44 |
-
--------------------------------------------------------------------
|
45 |
-
1. demucs
|
46 |
-
Copyright (c) Meta Platforms, Inc. and affiliates.
|
47 |
-
|
48 |
-
|
49 |
-
A copy of the MIT is included in this file.
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
|
54 |
-
--------------------------------------------------------------------
|
55 |
-
1. torch
|
56 |
-
From PyTorch:
|
57 |
-
|
58 |
-
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
59 |
-
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
60 |
-
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
61 |
-
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
62 |
-
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
63 |
-
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
64 |
-
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
65 |
-
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
66 |
-
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
67 |
-
|
68 |
-
From Caffe2:
|
69 |
-
|
70 |
-
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
|
71 |
-
|
72 |
-
All contributions by Facebook:
|
73 |
-
Copyright (c) 2016 Facebook Inc.
|
74 |
-
|
75 |
-
All contributions by Google:
|
76 |
-
Copyright (c) 2015 Google Inc.
|
77 |
-
All rights reserved.
|
78 |
-
|
79 |
-
All contributions by Yangqing Jia:
|
80 |
-
Copyright (c) 2015 Yangqing Jia
|
81 |
-
All rights reserved.
|
82 |
-
|
83 |
-
All contributions by Kakao Brain:
|
84 |
-
Copyright 2019-2020 Kakao Brain
|
85 |
-
|
86 |
-
All contributions by Cruise LLC:
|
87 |
-
Copyright (c) 2022 Cruise LLC.
|
88 |
-
All rights reserved.
|
89 |
-
|
90 |
-
All contributions from Caffe:
|
91 |
-
Copyright(c) 2013, 2014, 2015, the respective contributors
|
92 |
-
All rights reserved.
|
93 |
-
|
94 |
-
All other contributions:
|
95 |
-
Copyright(c) 2015, 2016 the respective contributors
|
96 |
-
All rights reserved.
|
97 |
-
|
98 |
-
Caffe2 uses a copyright model similar to Caffe: each contributor holds
|
99 |
-
copyright over their contributions to Caffe2. The project versioning records
|
100 |
-
all such contribution and copyright details. If a contributor wants to further
|
101 |
-
mark their specific copyright on a particular contribution, they should
|
102 |
-
indicate their copyright solely in the commit message of the change when it is
|
103 |
-
committed.
|
104 |
-
|
105 |
-
All rights reserved.
|
106 |
-
|
107 |
-
|
108 |
-
Terms of the BSD 3-Clause:
|
109 |
-
--------------------------------------------------------------------
|
110 |
-
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
111 |
-
|
112 |
-
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
113 |
-
|
114 |
-
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
115 |
-
|
116 |
-
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
117 |
-
|
118 |
-
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
119 |
-
|
120 |
-
For the license of other third party components, please refer to the following URL:
|
121 |
-
https://github.com/pytorch/pytorch/blob/v2.0.1/NOTICE
|
122 |
-
|
123 |
-
|
124 |
-
Open Source Software Licensed under the BSD 2-Clause License and Other Licenses of the Third-Party Components therein:
|
125 |
-
--------------------------------------------------------------------
|
126 |
-
1. torchaudio
|
127 |
-
Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
|
128 |
-
All rights reserved.
|
129 |
-
|
130 |
-
|
131 |
-
Terms of the BSD 2-Clause:
|
132 |
-
--------------------------------------------------------------------
|
133 |
-
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
134 |
-
|
135 |
-
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
136 |
-
|
137 |
-
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
138 |
-
|
139 |
-
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
140 |
-
|
141 |
-
For the license of other third party components, please refer to the following URL:
|
142 |
-
https://github.com/pytorch/audio/blob/v2.0.2/LICENSE
|
143 |
-
|
144 |
-
|
145 |
-
Open Source Software License under the Apache License Version 2.0:
|
146 |
-
--------------------------------------------------------------------
|
147 |
-
1. huggingface-hub
|
148 |
-
Copyright (c) huggingface-hub original author and authors
|
149 |
-
|
150 |
-
2. transformers
|
151 |
-
Copyright 2018- The Hugging Face team. All rights reserved.
|
152 |
-
|
153 |
-
|
154 |
-
Terms of the Apache License Version 2.0:
|
155 |
-
--------------------------------------------------------------------
|
156 |
-
Apache License
|
157 |
-
|
158 |
-
Version 2.0, January 2004
|
159 |
-
|
160 |
-
http://www.apache.org/licenses/
|
161 |
-
|
162 |
-
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
163 |
-
1. Definitions.
|
164 |
-
|
165 |
-
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
166 |
-
|
167 |
-
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
168 |
-
|
169 |
-
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
170 |
-
|
171 |
-
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
172 |
-
|
173 |
-
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
174 |
-
|
175 |
-
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
176 |
-
|
177 |
-
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
178 |
-
|
179 |
-
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
180 |
-
|
181 |
-
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
182 |
-
|
183 |
-
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
184 |
-
|
185 |
-
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
186 |
-
|
187 |
-
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
188 |
-
|
189 |
-
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
190 |
-
|
191 |
-
You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
192 |
-
|
193 |
-
You must cause any modified files to carry prominent notices stating that You changed the files; and
|
194 |
-
|
195 |
-
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
196 |
-
|
197 |
-
If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
198 |
-
|
199 |
-
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
200 |
-
|
201 |
-
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
202 |
-
|
203 |
-
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
204 |
-
|
205 |
-
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
206 |
-
|
207 |
-
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
208 |
-
|
209 |
-
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
210 |
-
|
211 |
-
END OF TERMS AND CONDITIONS
|
|
|
1 |
+
Tencent is pleased to support the open source community by making SongGeneration available.
|
2 |
+
|
3 |
+
Copyright (C) 2025 Tencent. All rights reserved.
|
4 |
+
|
5 |
+
SongGeneration is licensed under the License Terms of SongGeneration except for the third-party components listed below, which is licensed under different terms. SongGeneration does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
|
6 |
+
|
7 |
+
|
8 |
+
License Terms of SongGeneration:
|
9 |
+
--------------------------------------------------------------------
|
10 |
+
|
11 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
12 |
+
|
13 |
+
- You agree to use the SongGeneration only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances.
|
14 |
+
|
15 |
+
- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
16 |
+
|
17 |
+
For avoidance of doubts, "Software" means the SongGeneration inference-enabling code and the weights made available under this license excluding any pre-trained data and other AI components.
|
18 |
+
|
19 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
20 |
+
|
21 |
+
|
22 |
+
Other dependencies and licenses:
|
23 |
+
|
24 |
+
|
25 |
+
Open Source Software Licensed under the MIT License and Other Licenses of the Third-Party Components therein:
|
26 |
+
--------------------------------------------------------------------
|
27 |
+
1. stable_audio_tools
|
28 |
+
Copyright (c) 2023 Stability AI
|
29 |
+
|
30 |
+
|
31 |
+
Terms of the MIT:
|
32 |
+
--------------------------------------------------------------------
|
33 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
34 |
+
|
35 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
36 |
+
|
37 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
38 |
+
|
39 |
+
For the license of other third party components, please refer to the following URL:
|
40 |
+
https://github.com/Stability-AI/stable-audio-tools/tree/main/LICENSES
|
41 |
+
|
42 |
+
|
43 |
+
Open Source Software Licensed under the MIT License:
|
44 |
+
--------------------------------------------------------------------
|
45 |
+
1. demucs
|
46 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
47 |
+
|
48 |
+
|
49 |
+
A copy of the MIT is included in this file.
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
|
54 |
+
--------------------------------------------------------------------
|
55 |
+
1. torch
|
56 |
+
From PyTorch:
|
57 |
+
|
58 |
+
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
59 |
+
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
60 |
+
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
61 |
+
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
62 |
+
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
63 |
+
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
64 |
+
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
65 |
+
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
66 |
+
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
67 |
+
|
68 |
+
From Caffe2:
|
69 |
+
|
70 |
+
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
|
71 |
+
|
72 |
+
All contributions by Facebook:
|
73 |
+
Copyright (c) 2016 Facebook Inc.
|
74 |
+
|
75 |
+
All contributions by Google:
|
76 |
+
Copyright (c) 2015 Google Inc.
|
77 |
+
All rights reserved.
|
78 |
+
|
79 |
+
All contributions by Yangqing Jia:
|
80 |
+
Copyright (c) 2015 Yangqing Jia
|
81 |
+
All rights reserved.
|
82 |
+
|
83 |
+
All contributions by Kakao Brain:
|
84 |
+
Copyright 2019-2020 Kakao Brain
|
85 |
+
|
86 |
+
All contributions by Cruise LLC:
|
87 |
+
Copyright (c) 2022 Cruise LLC.
|
88 |
+
All rights reserved.
|
89 |
+
|
90 |
+
All contributions from Caffe:
|
91 |
+
Copyright(c) 2013, 2014, 2015, the respective contributors
|
92 |
+
All rights reserved.
|
93 |
+
|
94 |
+
All other contributions:
|
95 |
+
Copyright(c) 2015, 2016 the respective contributors
|
96 |
+
All rights reserved.
|
97 |
+
|
98 |
+
Caffe2 uses a copyright model similar to Caffe: each contributor holds
|
99 |
+
copyright over their contributions to Caffe2. The project versioning records
|
100 |
+
all such contribution and copyright details. If a contributor wants to further
|
101 |
+
mark their specific copyright on a particular contribution, they should
|
102 |
+
indicate their copyright solely in the commit message of the change when it is
|
103 |
+
committed.
|
104 |
+
|
105 |
+
All rights reserved.
|
106 |
+
|
107 |
+
|
108 |
+
Terms of the BSD 3-Clause:
|
109 |
+
--------------------------------------------------------------------
|
110 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
111 |
+
|
112 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
113 |
+
|
114 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
115 |
+
|
116 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
117 |
+
|
118 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
119 |
+
|
120 |
+
For the license of other third party components, please refer to the following URL:
|
121 |
+
https://github.com/pytorch/pytorch/blob/v2.0.1/NOTICE
|
122 |
+
|
123 |
+
|
124 |
+
Open Source Software Licensed under the BSD 2-Clause License and Other Licenses of the Third-Party Components therein:
|
125 |
+
--------------------------------------------------------------------
|
126 |
+
1. torchaudio
|
127 |
+
Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
|
128 |
+
All rights reserved.
|
129 |
+
|
130 |
+
|
131 |
+
Terms of the BSD 2-Clause:
|
132 |
+
--------------------------------------------------------------------
|
133 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
134 |
+
|
135 |
+
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
136 |
+
|
137 |
+
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
138 |
+
|
139 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
140 |
+
|
141 |
+
For the license of other third party components, please refer to the following URL:
|
142 |
+
https://github.com/pytorch/audio/blob/v2.0.2/LICENSE
|
143 |
+
|
144 |
+
|
145 |
+
Open Source Software License under the Apache License Version 2.0:
|
146 |
+
--------------------------------------------------------------------
|
147 |
+
1. huggingface-hub
|
148 |
+
Copyright (c) huggingface-hub original author and authors
|
149 |
+
|
150 |
+
2. transformers
|
151 |
+
Copyright 2018- The Hugging Face team. All rights reserved.
|
152 |
+
|
153 |
+
|
154 |
+
Terms of the Apache License Version 2.0:
|
155 |
+
--------------------------------------------------------------------
|
156 |
+
Apache License
|
157 |
+
|
158 |
+
Version 2.0, January 2004
|
159 |
+
|
160 |
+
http://www.apache.org/licenses/
|
161 |
+
|
162 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
163 |
+
1. Definitions.
|
164 |
+
|
165 |
+
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
166 |
+
|
167 |
+
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
168 |
+
|
169 |
+
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
170 |
+
|
171 |
+
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
172 |
+
|
173 |
+
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
174 |
+
|
175 |
+
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
176 |
+
|
177 |
+
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
178 |
+
|
179 |
+
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
180 |
+
|
181 |
+
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
182 |
+
|
183 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
184 |
+
|
185 |
+
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
186 |
+
|
187 |
+
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
188 |
+
|
189 |
+
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
190 |
+
|
191 |
+
You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
192 |
+
|
193 |
+
You must cause any modified files to carry prominent notices stating that You changed the files; and
|
194 |
+
|
195 |
+
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
196 |
+
|
197 |
+
If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
198 |
+
|
199 |
+
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
200 |
+
|
201 |
+
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
202 |
+
|
203 |
+
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
204 |
+
|
205 |
+
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
206 |
+
|
207 |
+
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
208 |
+
|
209 |
+
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
210 |
+
|
211 |
+
END OF TERMS AND CONDITIONS
|
README.md
CHANGED
@@ -7,11 +7,9 @@ sdk: docker
|
|
7 |
app_port: 7860
|
8 |
---
|
9 |
|
10 |
-
|
11 |
<p align="center">
|
12 |
<a href="https://levo-demo.github.io/">Demo</a> | <a href="https://arxiv.org/abs/2506.07520">Paper</a> | <a href="https://github.com/tencent-ailab/songgeneration">Code</a>
|
13 |
</p>
|
14 |
-
|
15 |
This repository is the official weight repository for LeVo: High-Quality Song Generation with Multi-Preference Alignment. In this repository, we provide the SongGeneration model, inference scripts, and the checkpoint that has been trained on the Million Song Dataset.
|
16 |
|
17 |
## Overview
|
|
|
7 |
app_port: 7860
|
8 |
---
|
9 |
|
|
|
10 |
<p align="center">
|
11 |
<a href="https://levo-demo.github.io/">Demo</a> | <a href="https://arxiv.org/abs/2506.07520">Paper</a> | <a href="https://github.com/tencent-ailab/songgeneration">Code</a>
|
12 |
</p>
|
|
|
13 |
This repository is the official weight repository for LeVo: High-Quality Song Generation with Multi-Preference Alignment. In this repository, we provide the SongGeneration model, inference scripts, and the checkpoint that has been trained on the Million Song Dataset.
|
14 |
|
15 |
## Overview
|
app.py
CHANGED
@@ -124,9 +124,9 @@ def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_co
|
|
124 |
|
125 |
|
126 |
# 创建Gradio界面
|
127 |
-
with gr.Blocks(title="
|
128 |
-
gr.Markdown("# 🎵
|
129 |
-
gr.Markdown("Demo interface for the song generation model. Provide a lyrics, and optionally an audio or text prompt, to generate a custom song.")
|
130 |
|
131 |
with gr.Row():
|
132 |
with gr.Column():
|
|
|
124 |
|
125 |
|
126 |
# 创建Gradio界面
|
127 |
+
with gr.Blocks(title="SongGeneration Demo Space") as demo:
|
128 |
+
gr.Markdown("# 🎵 SongGeneration Demo Space")
|
129 |
+
gr.Markdown("Demo interface for the song generation model. Provide a lyrics, and optionally an audio or text prompt, to generate a custom song. The code is in [GIT](https://github.com/tencent-ailab/SongGeneration)")
|
130 |
|
131 |
with gr.Row():
|
132 |
with gr.Column():
|
codeclm/models/codeclm.py
CHANGED
@@ -36,6 +36,10 @@ class CodecLM:
|
|
36 |
max_duration: tp.Optional[float] = None, seperate_tokenizer: AudioTokenizer = None):
|
37 |
self.name = name
|
38 |
self.audiotokenizer = audiotokenizer
|
|
|
|
|
|
|
|
|
39 |
self.lm = lm
|
40 |
self.seperate_tokenizer = seperate_tokenizer
|
41 |
# import pdb; pdb.set_trace()
|
@@ -47,7 +51,7 @@ class CodecLM:
|
|
47 |
assert max_duration is not None
|
48 |
|
49 |
self.max_duration: float = max_duration
|
50 |
-
self.device =
|
51 |
self.generation_params: dict = {}
|
52 |
# self.set_generation_params(duration=15) # 15 seconds by default
|
53 |
self.set_generation_params(duration=15, extend_stride=self.max_duration // 2)
|
@@ -57,23 +61,6 @@ class CodecLM:
|
|
57 |
else:
|
58 |
self.autocast = TorchAutocast(enabled=False)
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
@property
|
63 |
-
def frame_rate(self) -> float:
|
64 |
-
"""Roughly the number of AR steps per seconds."""
|
65 |
-
return self.audiotokenizer.frame_rate
|
66 |
-
|
67 |
-
@property
|
68 |
-
def sample_rate(self) -> int:
|
69 |
-
"""Sample rate of the generated audio."""
|
70 |
-
return self.audiotokenizer.sample_rate
|
71 |
-
|
72 |
-
@property
|
73 |
-
def audio_channels(self) -> int:
|
74 |
-
"""Audio channels of the generated audio."""
|
75 |
-
return self.audiotokenizer.channels
|
76 |
-
|
77 |
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
78 |
top_p: float = 0.0, temperature: float = 1.0,
|
79 |
duration: float = 30.0, cfg_coef: float = 3.0,
|
@@ -185,7 +172,7 @@ class CodecLM:
|
|
185 |
assert len(lyrics) == 1
|
186 |
texts = [lyric for lyric in lyrics]
|
187 |
audio_qt_embs = []
|
188 |
-
target_melody_token_len = self.lm.cfg.prompt_len * self.
|
189 |
# import pdb; pdb.set_trace()
|
190 |
if melody_wavs is None:
|
191 |
melody_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
|
@@ -207,39 +194,39 @@ class CodecLM:
|
|
207 |
melody_tokens = melody_tokens[...,:target_melody_token_len]
|
208 |
elif melody_tokens.shape[-1] < target_melody_token_len:
|
209 |
melody_tokens = torch.cat([melody_tokens, torch.full((1,1,target_melody_token_len - melody_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
else:
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
f"got vocal len={vocal_tokens.shape[-1]}, and bgm len={bgm_tokens.shape[-1]}"
|
234 |
-
if bgm_tokens.shape[-1] > target_melody_token_len:
|
235 |
-
bgm_tokens = bgm_tokens[...,:target_melody_token_len]
|
236 |
-
elif bgm_tokens.shape[-1] < target_melody_token_len:
|
237 |
-
bgm_tokens = torch.cat([bgm_tokens, torch.full((1,1,target_melody_token_len - bgm_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
|
238 |
-
if vocal_tokens.shape[-1] > target_melody_token_len:
|
239 |
-
vocal_tokens = vocal_tokens[...,:target_melody_token_len]
|
240 |
-
elif vocal_tokens.shape[-1] < target_melody_token_len:
|
241 |
-
vocal_tokens = torch.cat([vocal_tokens, torch.full((1,1,target_melody_token_len - vocal_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
|
242 |
-
melody_tokens = torch.cat([melody_tokens, vocal_tokens, bgm_tokens], dim=1)
|
243 |
assert melody_tokens.shape[-1] == target_melody_token_len
|
244 |
audio_qt_embs = melody_tokens.long()
|
245 |
return texts, audio_qt_embs
|
@@ -284,7 +271,7 @@ class CodecLM:
|
|
284 |
return gen_tokens
|
285 |
|
286 |
@torch.no_grad()
|
287 |
-
def generate_audio(self, gen_tokens: torch.Tensor, prompt=None, vocal_prompt=None, bgm_prompt=None):
|
288 |
"""Generate Audio from tokens"""
|
289 |
assert gen_tokens.dim() == 3
|
290 |
if self.seperate_tokenizer is not None:
|
@@ -292,7 +279,7 @@ class CodecLM:
|
|
292 |
gen_tokens_vocal = gen_tokens[:, [1], :]
|
293 |
gen_tokens_bgm = gen_tokens[:, [2], :]
|
294 |
# gen_audio_song = self.audiotokenizer.decode(gen_tokens_song, prompt)
|
295 |
-
gen_audio_seperate = self.seperate_tokenizer.decode([gen_tokens_vocal, gen_tokens_bgm], vocal_prompt, bgm_prompt)
|
296 |
return gen_audio_seperate
|
297 |
else:
|
298 |
gen_audio = self.audiotokenizer.decode(gen_tokens, prompt)
|
|
|
36 |
max_duration: tp.Optional[float] = None, seperate_tokenizer: AudioTokenizer = None):
|
37 |
self.name = name
|
38 |
self.audiotokenizer = audiotokenizer
|
39 |
+
if self.audiotokenizer:
|
40 |
+
self.frame_rate = self.audiotokenizer.frame_rate
|
41 |
+
else:
|
42 |
+
self.frame_rate = 25
|
43 |
self.lm = lm
|
44 |
self.seperate_tokenizer = seperate_tokenizer
|
45 |
# import pdb; pdb.set_trace()
|
|
|
51 |
assert max_duration is not None
|
52 |
|
53 |
self.max_duration: float = max_duration
|
54 |
+
self.device = torch.device("cuda")
|
55 |
self.generation_params: dict = {}
|
56 |
# self.set_generation_params(duration=15) # 15 seconds by default
|
57 |
self.set_generation_params(duration=15, extend_stride=self.max_duration // 2)
|
|
|
61 |
else:
|
62 |
self.autocast = TorchAutocast(enabled=False)
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
65 |
top_p: float = 0.0, temperature: float = 1.0,
|
66 |
duration: float = 30.0, cfg_coef: float = 3.0,
|
|
|
172 |
assert len(lyrics) == 1
|
173 |
texts = [lyric for lyric in lyrics]
|
174 |
audio_qt_embs = []
|
175 |
+
target_melody_token_len = self.lm.cfg.prompt_len * self.frame_rate
|
176 |
# import pdb; pdb.set_trace()
|
177 |
if melody_wavs is None:
|
178 |
melody_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
|
|
|
194 |
melody_tokens = melody_tokens[...,:target_melody_token_len]
|
195 |
elif melody_tokens.shape[-1] < target_melody_token_len:
|
196 |
melody_tokens = torch.cat([melody_tokens, torch.full((1,1,target_melody_token_len - melody_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
|
197 |
+
|
198 |
+
if bgm_wavs is None:
|
199 |
+
assert vocal_wavs is None, "vocal_wavs is not None when bgm_wavs is None"
|
200 |
+
bgm_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
|
201 |
+
vocal_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
|
202 |
+
else:
|
203 |
+
assert vocal_wavs is not None, "vocal_wavs is None when bgm_wavs is not None"
|
204 |
+
if type(vocal_wavs) == list:
|
205 |
+
vocal_wavs = torch.stack(vocal_wavs, dim=0)
|
206 |
+
if type(bgm_wavs) == list:
|
207 |
+
bgm_wavs = torch.stack(bgm_wavs, dim=0)
|
208 |
+
vocal_wavs = vocal_wavs.to(self.device)
|
209 |
+
bgm_wavs = bgm_wavs.to(self.device)
|
210 |
+
if melody_is_wav:
|
211 |
+
vocal_tokens, bgm_tokens = self.seperate_tokenizer.encode(vocal_wavs, bgm_wavs)
|
212 |
else:
|
213 |
+
vocal_tokens = vocal_wavs
|
214 |
+
bgm_tokens = bgm_wavs
|
215 |
+
assert len(vocal_tokens.shape) == len(bgm_tokens.shape) == 3, \
|
216 |
+
f"vocal and bgm tokens should have a shape [B, C, T]! " \
|
217 |
+
f"got vocal len={vocal_tokens.shape}, and bgm len={bgm_tokens.shape}"
|
218 |
+
assert vocal_tokens.shape[-1] == bgm_tokens.shape[-1], \
|
219 |
+
f"vocal and bgm tokens should have the same length! " \
|
220 |
+
f"got vocal len={vocal_tokens.shape[-1]}, and bgm len={bgm_tokens.shape[-1]}"
|
221 |
+
if bgm_tokens.shape[-1] > target_melody_token_len:
|
222 |
+
bgm_tokens = bgm_tokens[...,:target_melody_token_len]
|
223 |
+
elif bgm_tokens.shape[-1] < target_melody_token_len:
|
224 |
+
bgm_tokens = torch.cat([bgm_tokens, torch.full((1,1,target_melody_token_len - bgm_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
|
225 |
+
if vocal_tokens.shape[-1] > target_melody_token_len:
|
226 |
+
vocal_tokens = vocal_tokens[...,:target_melody_token_len]
|
227 |
+
elif vocal_tokens.shape[-1] < target_melody_token_len:
|
228 |
+
vocal_tokens = torch.cat([vocal_tokens, torch.full((1,1,target_melody_token_len - vocal_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
|
229 |
+
melody_tokens = torch.cat([melody_tokens, vocal_tokens, bgm_tokens], dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
assert melody_tokens.shape[-1] == target_melody_token_len
|
231 |
audio_qt_embs = melody_tokens.long()
|
232 |
return texts, audio_qt_embs
|
|
|
271 |
return gen_tokens
|
272 |
|
273 |
@torch.no_grad()
|
274 |
+
def generate_audio(self, gen_tokens: torch.Tensor, prompt=None, vocal_prompt=None, bgm_prompt=None, chunked=False):
|
275 |
"""Generate Audio from tokens"""
|
276 |
assert gen_tokens.dim() == 3
|
277 |
if self.seperate_tokenizer is not None:
|
|
|
279 |
gen_tokens_vocal = gen_tokens[:, [1], :]
|
280 |
gen_tokens_bgm = gen_tokens[:, [2], :]
|
281 |
# gen_audio_song = self.audiotokenizer.decode(gen_tokens_song, prompt)
|
282 |
+
gen_audio_seperate = self.seperate_tokenizer.decode([gen_tokens_vocal, gen_tokens_bgm], vocal_prompt, bgm_prompt, chunked=chunked)
|
283 |
return gen_audio_seperate
|
284 |
else:
|
285 |
gen_audio = self.audiotokenizer.decode(gen_tokens, prompt)
|
codeclm/tokenizer/Flow1dVAE/generate_septoken.py
CHANGED
@@ -173,7 +173,7 @@ class Tango:
|
|
173 |
return codes_vocal, codes_bgm
|
174 |
|
175 |
@torch.no_grad()
|
176 |
-
def code2sound(self, codes, prompt_vocal=None, prompt_bgm=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False):
|
177 |
codes_vocal,codes_bgm = codes
|
178 |
codes_vocal = codes_vocal.to(self.device)
|
179 |
codes_bgm = codes_bgm.to(self.device)
|
@@ -268,11 +268,12 @@ class Tango:
|
|
268 |
min_samples = int(min_samples * self.sample_rate // 1000 * 40)
|
269 |
hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
|
270 |
ovlp_samples = min_samples - hop_samples
|
|
|
271 |
with torch.no_grad():
|
272 |
output = None
|
273 |
for i in range(len(latent_list)):
|
274 |
latent = latent_list[i]
|
275 |
-
cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
|
276 |
|
277 |
if output is None:
|
278 |
output = cur_output
|
|
|
173 |
return codes_vocal, codes_bgm
|
174 |
|
175 |
@torch.no_grad()
|
176 |
+
def code2sound(self, codes, prompt_vocal=None, prompt_bgm=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False, chunked=False):
|
177 |
codes_vocal,codes_bgm = codes
|
178 |
codes_vocal = codes_vocal.to(self.device)
|
179 |
codes_bgm = codes_bgm.to(self.device)
|
|
|
268 |
min_samples = int(min_samples * self.sample_rate // 1000 * 40)
|
269 |
hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
|
270 |
ovlp_samples = min_samples - hop_samples
|
271 |
+
torch.cuda.empty_cache()
|
272 |
with torch.no_grad():
|
273 |
output = None
|
274 |
for i in range(len(latent_list)):
|
275 |
latent = latent_list[i]
|
276 |
+
cur_output = self.vae.decode_audio(latent, chunked=chunked)[0].detach().cpu()
|
277 |
|
278 |
if output is None:
|
279 |
output = cur_output
|
codeclm/tokenizer/Flow1dVAE/libs/rvq/core_vq.py
CHANGED
@@ -1,366 +1,366 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
#
|
7 |
-
# This implementation is inspired from
|
8 |
-
# https://github.com/lucidrains/vector-quantize-pytorch
|
9 |
-
# which is released under MIT License. Hereafter, the original license:
|
10 |
-
# MIT License
|
11 |
-
#
|
12 |
-
# Copyright (c) 2020 Phil Wang
|
13 |
-
#
|
14 |
-
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
15 |
-
# of this software and associated documentation files (the "Software"), to deal
|
16 |
-
# in the Software without restriction, including without limitation the rights
|
17 |
-
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
18 |
-
# copies of the Software, and to permit persons to whom the Software is
|
19 |
-
# furnished to do so, subject to the following conditions:
|
20 |
-
#
|
21 |
-
# The above copyright notice and this permission notice shall be included in all
|
22 |
-
# copies or substantial portions of the Software.
|
23 |
-
#
|
24 |
-
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
25 |
-
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
26 |
-
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
27 |
-
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
28 |
-
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
29 |
-
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
30 |
-
# SOFTWARE.
|
31 |
-
|
32 |
-
"""Core vector quantization implementation."""
|
33 |
-
|
34 |
-
import typing as tp
|
35 |
-
|
36 |
-
from einops import rearrange, repeat
|
37 |
-
import torch
|
38 |
-
from torch import nn
|
39 |
-
import torch.nn.functional as F
|
40 |
-
|
41 |
-
# from .. import distrib
|
42 |
-
|
43 |
-
|
44 |
-
def default(val: tp.Any, d: tp.Any) -> tp.Any:
|
45 |
-
return val if val is not None else d
|
46 |
-
|
47 |
-
|
48 |
-
def ema_inplace(moving_avg, new, decay: float):
|
49 |
-
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
50 |
-
|
51 |
-
|
52 |
-
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
|
53 |
-
return (x + epsilon) / (x.sum() + n_categories * epsilon)
|
54 |
-
|
55 |
-
|
56 |
-
def uniform_init(*shape: int):
|
57 |
-
t = torch.empty(shape)
|
58 |
-
nn.init.kaiming_uniform_(t)
|
59 |
-
return t
|
60 |
-
|
61 |
-
|
62 |
-
def sample_vectors(samples, num: int):
|
63 |
-
num_samples, device = samples.shape[0], samples.device
|
64 |
-
|
65 |
-
if num_samples >= num:
|
66 |
-
indices = torch.randperm(num_samples, device=device)[:num]
|
67 |
-
else:
|
68 |
-
indices = torch.randint(0, num_samples, (num,), device=device)
|
69 |
-
|
70 |
-
return samples[indices]
|
71 |
-
|
72 |
-
|
73 |
-
def kmeans(samples, num_clusters: int, num_iters: int = 10):
|
74 |
-
dim, dtype = samples.shape[-1], samples.dtype
|
75 |
-
|
76 |
-
means = sample_vectors(samples, num_clusters)
|
77 |
-
|
78 |
-
for _ in range(num_iters):
|
79 |
-
diffs = rearrange(samples, "n d -> n () d") - rearrange(
|
80 |
-
means, "c d -> () c d"
|
81 |
-
)
|
82 |
-
dists = -(diffs ** 2).sum(dim=-1)
|
83 |
-
|
84 |
-
buckets = dists.max(dim=-1).indices
|
85 |
-
bins = torch.bincount(buckets, minlength=num_clusters)
|
86 |
-
zero_mask = bins == 0
|
87 |
-
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
88 |
-
|
89 |
-
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
90 |
-
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
91 |
-
new_means = new_means / bins_min_clamped[..., None]
|
92 |
-
|
93 |
-
means = torch.where(zero_mask[..., None], means, new_means)
|
94 |
-
|
95 |
-
return means, bins
|
96 |
-
|
97 |
-
|
98 |
-
class EuclideanCodebook(nn.Module):
|
99 |
-
"""Codebook with Euclidean distance.
|
100 |
-
Args:
|
101 |
-
dim (int): Dimension.
|
102 |
-
codebook_size (int): Codebook size.
|
103 |
-
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
|
104 |
-
If set to true, run the k-means algorithm on the first training batch and use
|
105 |
-
the learned centroids as initialization.
|
106 |
-
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
|
107 |
-
decay (float): Decay for exponential moving average over the codebooks.
|
108 |
-
epsilon (float): Epsilon value for numerical stability.
|
109 |
-
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
110 |
-
that have an exponential moving average cluster size less than the specified threshold with
|
111 |
-
randomly selected vector from the current batch.
|
112 |
-
"""
|
113 |
-
def __init__(
|
114 |
-
self,
|
115 |
-
dim: int,
|
116 |
-
codebook_size: int,
|
117 |
-
kmeans_init: int = False,
|
118 |
-
kmeans_iters: int = 10,
|
119 |
-
decay: float = 0.99,
|
120 |
-
epsilon: float = 1e-5,
|
121 |
-
threshold_ema_dead_code: int = 2,
|
122 |
-
):
|
123 |
-
super().__init__()
|
124 |
-
self.decay = decay
|
125 |
-
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
|
126 |
-
embed = init_fn(codebook_size, dim)
|
127 |
-
|
128 |
-
self.codebook_size = codebook_size
|
129 |
-
|
130 |
-
self.kmeans_iters = kmeans_iters
|
131 |
-
self.epsilon = epsilon
|
132 |
-
self.threshold_ema_dead_code = threshold_ema_dead_code
|
133 |
-
|
134 |
-
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
|
135 |
-
self.register_buffer("cluster_size", torch.zeros(codebook_size))
|
136 |
-
self.register_buffer("embed", embed)
|
137 |
-
self.register_buffer("embed_avg", embed.clone())
|
138 |
-
|
139 |
-
self.runed_steps = 0
|
140 |
-
self.stop_steps = 50_000
|
141 |
-
|
142 |
-
@torch.jit.ignore
|
143 |
-
def init_embed_(self, data):
|
144 |
-
if self.inited:
|
145 |
-
return
|
146 |
-
|
147 |
-
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
148 |
-
self.embed.data.copy_(embed)
|
149 |
-
self.embed_avg.data.copy_(embed.clone())
|
150 |
-
self.cluster_size.data.copy_(cluster_size)
|
151 |
-
self.inited.data.copy_(torch.Tensor([True]))
|
152 |
-
# Make sure all buffers across workers are in sync after initialization
|
153 |
-
distrib.broadcast_tensors(self.buffers())
|
154 |
-
|
155 |
-
def replace_(self, samples, mask):
|
156 |
-
modified_codebook = torch.where(
|
157 |
-
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
158 |
-
)
|
159 |
-
self.embed.data.copy_(modified_codebook)
|
160 |
-
|
161 |
-
def expire_codes_(self, batch_samples):
|
162 |
-
if self.threshold_ema_dead_code == 0:
|
163 |
-
return
|
164 |
-
|
165 |
-
expired_codes = self.cluster_size < self.threshold_ema_dead_code
|
166 |
-
if not torch.any(expired_codes):
|
167 |
-
return
|
168 |
-
|
169 |
-
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
170 |
-
self.replace_(batch_samples, mask=expired_codes)
|
171 |
-
# distrib.broadcast_tensors(self.buffers())
|
172 |
-
|
173 |
-
def preprocess(self, x):
|
174 |
-
x = rearrange(x, "... d -> (...) d")
|
175 |
-
return x
|
176 |
-
|
177 |
-
def quantize(self, x):
|
178 |
-
embed = self.embed.t()
|
179 |
-
dist = -(
|
180 |
-
x.pow(2).sum(1, keepdim=True)
|
181 |
-
- 2 * x @ embed
|
182 |
-
+ embed.pow(2).sum(0, keepdim=True)
|
183 |
-
)
|
184 |
-
embed_ind = dist.max(dim=-1).indices
|
185 |
-
return embed_ind
|
186 |
-
|
187 |
-
def postprocess_emb(self, embed_ind, shape):
|
188 |
-
return embed_ind.view(*shape[:-1])
|
189 |
-
|
190 |
-
def dequantize(self, embed_ind):
|
191 |
-
quantize = F.embedding(embed_ind, self.embed)
|
192 |
-
return quantize
|
193 |
-
|
194 |
-
def encode(self, x):
|
195 |
-
shape = x.shape
|
196 |
-
# pre-process
|
197 |
-
x = self.preprocess(x)
|
198 |
-
# quantize
|
199 |
-
embed_ind = self.quantize(x)
|
200 |
-
# post-process
|
201 |
-
embed_ind = self.postprocess_emb(embed_ind, shape)
|
202 |
-
return embed_ind
|
203 |
-
|
204 |
-
def decode(self, embed_ind):
|
205 |
-
quantize = self.dequantize(embed_ind)
|
206 |
-
return quantize
|
207 |
-
|
208 |
-
def forward(self, x):
|
209 |
-
shape, dtype = x.shape, x.dtype
|
210 |
-
x = self.preprocess(x)
|
211 |
-
# self.init_embed_(x)
|
212 |
-
|
213 |
-
embed_ind = self.quantize(x)
|
214 |
-
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
|
215 |
-
embed_ind = self.postprocess_emb(embed_ind, shape)
|
216 |
-
quantize = self.dequantize(embed_ind)
|
217 |
-
self.runed_steps += 1
|
218 |
-
|
219 |
-
if self.training and self.runed_steps < self.stop_steps:
|
220 |
-
# We do the expiry of code at that point as buffers are in sync
|
221 |
-
# and all the workers will take the same decision.
|
222 |
-
self.expire_codes_(x)
|
223 |
-
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
224 |
-
embed_sum = x.t() @ embed_onehot
|
225 |
-
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
226 |
-
cluster_size = (
|
227 |
-
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
|
228 |
-
* self.cluster_size.sum()
|
229 |
-
)
|
230 |
-
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
231 |
-
self.embed.data.copy_(embed_normalized)
|
232 |
-
|
233 |
-
return quantize, embed_ind
|
234 |
-
|
235 |
-
|
236 |
-
class VectorQuantization(nn.Module):
|
237 |
-
"""Vector quantization implementation.
|
238 |
-
Currently supports only euclidean distance.
|
239 |
-
Args:
|
240 |
-
dim (int): Dimension
|
241 |
-
codebook_size (int): Codebook size
|
242 |
-
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
|
243 |
-
decay (float): Decay for exponential moving average over the codebooks.
|
244 |
-
epsilon (float): Epsilon value for numerical stability.
|
245 |
-
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
246 |
-
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
247 |
-
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
248 |
-
that have an exponential moving average cluster size less than the specified threshold with
|
249 |
-
randomly selected vector from the current batch.
|
250 |
-
commitment_weight (float): Weight for commitment loss.
|
251 |
-
"""
|
252 |
-
def __init__(
|
253 |
-
self,
|
254 |
-
dim: int,
|
255 |
-
codebook_size: int,
|
256 |
-
codebook_dim: tp.Optional[int] = None,
|
257 |
-
decay: float = 0.99,
|
258 |
-
epsilon: float = 1e-5,
|
259 |
-
kmeans_init: bool = True,
|
260 |
-
kmeans_iters: int = 50,
|
261 |
-
threshold_ema_dead_code: int = 2,
|
262 |
-
commitment_weight: float = 1.,
|
263 |
-
):
|
264 |
-
super().__init__()
|
265 |
-
_codebook_dim: int = default(codebook_dim, dim)
|
266 |
-
|
267 |
-
requires_projection = _codebook_dim != dim
|
268 |
-
self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
|
269 |
-
self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
|
270 |
-
|
271 |
-
self.epsilon = epsilon
|
272 |
-
self.commitment_weight = commitment_weight
|
273 |
-
|
274 |
-
self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
|
275 |
-
kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
|
276 |
-
decay=decay, epsilon=epsilon,
|
277 |
-
threshold_ema_dead_code=threshold_ema_dead_code)
|
278 |
-
self.codebook_size = codebook_size
|
279 |
-
|
280 |
-
@property
|
281 |
-
def codebook(self):
|
282 |
-
return self._codebook.embed
|
283 |
-
|
284 |
-
def encode(self, x):
|
285 |
-
x = rearrange(x, "b d n -> b n d")
|
286 |
-
x = self.project_in(x)
|
287 |
-
embed_in = self._codebook.encode(x)
|
288 |
-
return embed_in
|
289 |
-
|
290 |
-
def decode(self, embed_ind):
|
291 |
-
quantize = self._codebook.decode(embed_ind)
|
292 |
-
quantize = self.project_out(quantize)
|
293 |
-
quantize = rearrange(quantize, "b n d -> b d n")
|
294 |
-
return quantize
|
295 |
-
|
296 |
-
def forward(self, x, do_debug=False):
|
297 |
-
device = x.device
|
298 |
-
x = rearrange(x, "b d n -> b n d")
|
299 |
-
x = self.project_in(x)
|
300 |
-
|
301 |
-
quantize, embed_ind = self._codebook(x)
|
302 |
-
|
303 |
-
if self.training:
|
304 |
-
quantize = x + (quantize - x).detach()
|
305 |
-
|
306 |
-
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
|
307 |
-
|
308 |
-
if self.training:
|
309 |
-
if self.commitment_weight > 0:
|
310 |
-
commit_loss = F.mse_loss(quantize.detach(), x)
|
311 |
-
loss = loss + commit_loss * self.commitment_weight
|
312 |
-
quantize = self.project_out(quantize)
|
313 |
-
quantize = rearrange(quantize, "b n d -> b d n")
|
314 |
-
return quantize, embed_ind, loss
|
315 |
-
|
316 |
-
|
317 |
-
class ResidualVectorQuantization(nn.Module):
|
318 |
-
"""Residual vector quantization implementation.
|
319 |
-
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
|
320 |
-
"""
|
321 |
-
def __init__(self, *, num_quantizers, **kwargs):
|
322 |
-
super().__init__()
|
323 |
-
self.layers = nn.ModuleList(
|
324 |
-
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
|
325 |
-
)
|
326 |
-
|
327 |
-
def forward(self, x, n_q: tp.Optional[int] = None):
|
328 |
-
quantized_out = 0.0
|
329 |
-
residual = x
|
330 |
-
|
331 |
-
all_losses = []
|
332 |
-
all_indices = []
|
333 |
-
|
334 |
-
n_q = n_q or len(self.layers)
|
335 |
-
|
336 |
-
for layerinx, layer in enumerate(self.layers[:n_q]):
|
337 |
-
print("Layer {} Used ratio {:.1f}".format(layerinx, (layer._codebook.cluster_size > 1.0).sum() / layer._codebook.cluster_size.shape[0] * 100.))
|
338 |
-
quantized, indices, loss = layer(residual)
|
339 |
-
residual = residual - quantized
|
340 |
-
quantized_out = quantized_out + quantized
|
341 |
-
|
342 |
-
all_indices.append(indices)
|
343 |
-
all_losses.append(loss)
|
344 |
-
|
345 |
-
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
|
346 |
-
return quantized_out, out_indices, out_losses
|
347 |
-
|
348 |
-
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
|
349 |
-
residual = x
|
350 |
-
all_indices = []
|
351 |
-
n_q = n_q or len(self.layers)
|
352 |
-
for layer in self.layers[:n_q]:
|
353 |
-
indices = layer.encode(residual)
|
354 |
-
quantized = layer.decode(indices)
|
355 |
-
residual = residual - quantized
|
356 |
-
all_indices.append(indices)
|
357 |
-
out_indices = torch.stack(all_indices)
|
358 |
-
return out_indices
|
359 |
-
|
360 |
-
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
|
361 |
-
quantized_out = torch.tensor(0.0, device=q_indices.device)
|
362 |
-
for i, indices in enumerate(q_indices):
|
363 |
-
layer = self.layers[i]
|
364 |
-
quantized = layer.decode(indices)
|
365 |
-
quantized_out = quantized_out + quantized
|
366 |
-
return quantized_out
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
#
|
7 |
+
# This implementation is inspired from
|
8 |
+
# https://github.com/lucidrains/vector-quantize-pytorch
|
9 |
+
# which is released under MIT License. Hereafter, the original license:
|
10 |
+
# MIT License
|
11 |
+
#
|
12 |
+
# Copyright (c) 2020 Phil Wang
|
13 |
+
#
|
14 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
15 |
+
# of this software and associated documentation files (the "Software"), to deal
|
16 |
+
# in the Software without restriction, including without limitation the rights
|
17 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
18 |
+
# copies of the Software, and to permit persons to whom the Software is
|
19 |
+
# furnished to do so, subject to the following conditions:
|
20 |
+
#
|
21 |
+
# The above copyright notice and this permission notice shall be included in all
|
22 |
+
# copies or substantial portions of the Software.
|
23 |
+
#
|
24 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
25 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
26 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
27 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
28 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
29 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
30 |
+
# SOFTWARE.
|
31 |
+
|
32 |
+
"""Core vector quantization implementation."""
|
33 |
+
|
34 |
+
import typing as tp
|
35 |
+
|
36 |
+
from einops import rearrange, repeat
|
37 |
+
import torch
|
38 |
+
from torch import nn
|
39 |
+
import torch.nn.functional as F
|
40 |
+
|
41 |
+
# from .. import distrib
|
42 |
+
|
43 |
+
|
44 |
+
def default(val: tp.Any, d: tp.Any) -> tp.Any:
|
45 |
+
return val if val is not None else d
|
46 |
+
|
47 |
+
|
48 |
+
def ema_inplace(moving_avg, new, decay: float):
|
49 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
50 |
+
|
51 |
+
|
52 |
+
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
|
53 |
+
return (x + epsilon) / (x.sum() + n_categories * epsilon)
|
54 |
+
|
55 |
+
|
56 |
+
def uniform_init(*shape: int):
|
57 |
+
t = torch.empty(shape)
|
58 |
+
nn.init.kaiming_uniform_(t)
|
59 |
+
return t
|
60 |
+
|
61 |
+
|
62 |
+
def sample_vectors(samples, num: int):
|
63 |
+
num_samples, device = samples.shape[0], samples.device
|
64 |
+
|
65 |
+
if num_samples >= num:
|
66 |
+
indices = torch.randperm(num_samples, device=device)[:num]
|
67 |
+
else:
|
68 |
+
indices = torch.randint(0, num_samples, (num,), device=device)
|
69 |
+
|
70 |
+
return samples[indices]
|
71 |
+
|
72 |
+
|
73 |
+
def kmeans(samples, num_clusters: int, num_iters: int = 10):
|
74 |
+
dim, dtype = samples.shape[-1], samples.dtype
|
75 |
+
|
76 |
+
means = sample_vectors(samples, num_clusters)
|
77 |
+
|
78 |
+
for _ in range(num_iters):
|
79 |
+
diffs = rearrange(samples, "n d -> n () d") - rearrange(
|
80 |
+
means, "c d -> () c d"
|
81 |
+
)
|
82 |
+
dists = -(diffs ** 2).sum(dim=-1)
|
83 |
+
|
84 |
+
buckets = dists.max(dim=-1).indices
|
85 |
+
bins = torch.bincount(buckets, minlength=num_clusters)
|
86 |
+
zero_mask = bins == 0
|
87 |
+
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
88 |
+
|
89 |
+
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
90 |
+
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
91 |
+
new_means = new_means / bins_min_clamped[..., None]
|
92 |
+
|
93 |
+
means = torch.where(zero_mask[..., None], means, new_means)
|
94 |
+
|
95 |
+
return means, bins
|
96 |
+
|
97 |
+
|
98 |
+
class EuclideanCodebook(nn.Module):
|
99 |
+
"""Codebook with Euclidean distance.
|
100 |
+
Args:
|
101 |
+
dim (int): Dimension.
|
102 |
+
codebook_size (int): Codebook size.
|
103 |
+
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
|
104 |
+
If set to true, run the k-means algorithm on the first training batch and use
|
105 |
+
the learned centroids as initialization.
|
106 |
+
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
|
107 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
108 |
+
epsilon (float): Epsilon value for numerical stability.
|
109 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
110 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
111 |
+
randomly selected vector from the current batch.
|
112 |
+
"""
|
113 |
+
def __init__(
|
114 |
+
self,
|
115 |
+
dim: int,
|
116 |
+
codebook_size: int,
|
117 |
+
kmeans_init: int = False,
|
118 |
+
kmeans_iters: int = 10,
|
119 |
+
decay: float = 0.99,
|
120 |
+
epsilon: float = 1e-5,
|
121 |
+
threshold_ema_dead_code: int = 2,
|
122 |
+
):
|
123 |
+
super().__init__()
|
124 |
+
self.decay = decay
|
125 |
+
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
|
126 |
+
embed = init_fn(codebook_size, dim)
|
127 |
+
|
128 |
+
self.codebook_size = codebook_size
|
129 |
+
|
130 |
+
self.kmeans_iters = kmeans_iters
|
131 |
+
self.epsilon = epsilon
|
132 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
133 |
+
|
134 |
+
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
|
135 |
+
self.register_buffer("cluster_size", torch.zeros(codebook_size))
|
136 |
+
self.register_buffer("embed", embed)
|
137 |
+
self.register_buffer("embed_avg", embed.clone())
|
138 |
+
|
139 |
+
self.runed_steps = 0
|
140 |
+
self.stop_steps = 50_000
|
141 |
+
|
142 |
+
@torch.jit.ignore
|
143 |
+
def init_embed_(self, data):
|
144 |
+
if self.inited:
|
145 |
+
return
|
146 |
+
|
147 |
+
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
148 |
+
self.embed.data.copy_(embed)
|
149 |
+
self.embed_avg.data.copy_(embed.clone())
|
150 |
+
self.cluster_size.data.copy_(cluster_size)
|
151 |
+
self.inited.data.copy_(torch.Tensor([True]))
|
152 |
+
# Make sure all buffers across workers are in sync after initialization
|
153 |
+
distrib.broadcast_tensors(self.buffers())
|
154 |
+
|
155 |
+
def replace_(self, samples, mask):
|
156 |
+
modified_codebook = torch.where(
|
157 |
+
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
158 |
+
)
|
159 |
+
self.embed.data.copy_(modified_codebook)
|
160 |
+
|
161 |
+
def expire_codes_(self, batch_samples):
|
162 |
+
if self.threshold_ema_dead_code == 0:
|
163 |
+
return
|
164 |
+
|
165 |
+
expired_codes = self.cluster_size < self.threshold_ema_dead_code
|
166 |
+
if not torch.any(expired_codes):
|
167 |
+
return
|
168 |
+
|
169 |
+
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
170 |
+
self.replace_(batch_samples, mask=expired_codes)
|
171 |
+
# distrib.broadcast_tensors(self.buffers())
|
172 |
+
|
173 |
+
def preprocess(self, x):
|
174 |
+
x = rearrange(x, "... d -> (...) d")
|
175 |
+
return x
|
176 |
+
|
177 |
+
def quantize(self, x):
|
178 |
+
embed = self.embed.t()
|
179 |
+
dist = -(
|
180 |
+
x.pow(2).sum(1, keepdim=True)
|
181 |
+
- 2 * x @ embed
|
182 |
+
+ embed.pow(2).sum(0, keepdim=True)
|
183 |
+
)
|
184 |
+
embed_ind = dist.max(dim=-1).indices
|
185 |
+
return embed_ind
|
186 |
+
|
187 |
+
def postprocess_emb(self, embed_ind, shape):
|
188 |
+
return embed_ind.view(*shape[:-1])
|
189 |
+
|
190 |
+
def dequantize(self, embed_ind):
|
191 |
+
quantize = F.embedding(embed_ind, self.embed)
|
192 |
+
return quantize
|
193 |
+
|
194 |
+
def encode(self, x):
|
195 |
+
shape = x.shape
|
196 |
+
# pre-process
|
197 |
+
x = self.preprocess(x)
|
198 |
+
# quantize
|
199 |
+
embed_ind = self.quantize(x)
|
200 |
+
# post-process
|
201 |
+
embed_ind = self.postprocess_emb(embed_ind, shape)
|
202 |
+
return embed_ind
|
203 |
+
|
204 |
+
def decode(self, embed_ind):
|
205 |
+
quantize = self.dequantize(embed_ind)
|
206 |
+
return quantize
|
207 |
+
|
208 |
+
def forward(self, x):
|
209 |
+
shape, dtype = x.shape, x.dtype
|
210 |
+
x = self.preprocess(x)
|
211 |
+
# self.init_embed_(x)
|
212 |
+
|
213 |
+
embed_ind = self.quantize(x)
|
214 |
+
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
|
215 |
+
embed_ind = self.postprocess_emb(embed_ind, shape)
|
216 |
+
quantize = self.dequantize(embed_ind)
|
217 |
+
self.runed_steps += 1
|
218 |
+
|
219 |
+
if self.training and self.runed_steps < self.stop_steps:
|
220 |
+
# We do the expiry of code at that point as buffers are in sync
|
221 |
+
# and all the workers will take the same decision.
|
222 |
+
self.expire_codes_(x)
|
223 |
+
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
224 |
+
embed_sum = x.t() @ embed_onehot
|
225 |
+
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
226 |
+
cluster_size = (
|
227 |
+
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
|
228 |
+
* self.cluster_size.sum()
|
229 |
+
)
|
230 |
+
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
231 |
+
self.embed.data.copy_(embed_normalized)
|
232 |
+
|
233 |
+
return quantize, embed_ind
|
234 |
+
|
235 |
+
|
236 |
+
class VectorQuantization(nn.Module):
|
237 |
+
"""Vector quantization implementation.
|
238 |
+
Currently supports only euclidean distance.
|
239 |
+
Args:
|
240 |
+
dim (int): Dimension
|
241 |
+
codebook_size (int): Codebook size
|
242 |
+
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
|
243 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
244 |
+
epsilon (float): Epsilon value for numerical stability.
|
245 |
+
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
246 |
+
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
247 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
248 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
249 |
+
randomly selected vector from the current batch.
|
250 |
+
commitment_weight (float): Weight for commitment loss.
|
251 |
+
"""
|
252 |
+
def __init__(
|
253 |
+
self,
|
254 |
+
dim: int,
|
255 |
+
codebook_size: int,
|
256 |
+
codebook_dim: tp.Optional[int] = None,
|
257 |
+
decay: float = 0.99,
|
258 |
+
epsilon: float = 1e-5,
|
259 |
+
kmeans_init: bool = True,
|
260 |
+
kmeans_iters: int = 50,
|
261 |
+
threshold_ema_dead_code: int = 2,
|
262 |
+
commitment_weight: float = 1.,
|
263 |
+
):
|
264 |
+
super().__init__()
|
265 |
+
_codebook_dim: int = default(codebook_dim, dim)
|
266 |
+
|
267 |
+
requires_projection = _codebook_dim != dim
|
268 |
+
self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
|
269 |
+
self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
|
270 |
+
|
271 |
+
self.epsilon = epsilon
|
272 |
+
self.commitment_weight = commitment_weight
|
273 |
+
|
274 |
+
self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
|
275 |
+
kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
|
276 |
+
decay=decay, epsilon=epsilon,
|
277 |
+
threshold_ema_dead_code=threshold_ema_dead_code)
|
278 |
+
self.codebook_size = codebook_size
|
279 |
+
|
280 |
+
@property
|
281 |
+
def codebook(self):
|
282 |
+
return self._codebook.embed
|
283 |
+
|
284 |
+
def encode(self, x):
|
285 |
+
x = rearrange(x, "b d n -> b n d")
|
286 |
+
x = self.project_in(x)
|
287 |
+
embed_in = self._codebook.encode(x)
|
288 |
+
return embed_in
|
289 |
+
|
290 |
+
def decode(self, embed_ind):
|
291 |
+
quantize = self._codebook.decode(embed_ind)
|
292 |
+
quantize = self.project_out(quantize)
|
293 |
+
quantize = rearrange(quantize, "b n d -> b d n")
|
294 |
+
return quantize
|
295 |
+
|
296 |
+
def forward(self, x, do_debug=False):
|
297 |
+
device = x.device
|
298 |
+
x = rearrange(x, "b d n -> b n d")
|
299 |
+
x = self.project_in(x)
|
300 |
+
|
301 |
+
quantize, embed_ind = self._codebook(x)
|
302 |
+
|
303 |
+
if self.training:
|
304 |
+
quantize = x + (quantize - x).detach()
|
305 |
+
|
306 |
+
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
|
307 |
+
|
308 |
+
if self.training:
|
309 |
+
if self.commitment_weight > 0:
|
310 |
+
commit_loss = F.mse_loss(quantize.detach(), x)
|
311 |
+
loss = loss + commit_loss * self.commitment_weight
|
312 |
+
quantize = self.project_out(quantize)
|
313 |
+
quantize = rearrange(quantize, "b n d -> b d n")
|
314 |
+
return quantize, embed_ind, loss
|
315 |
+
|
316 |
+
|
317 |
+
class ResidualVectorQuantization(nn.Module):
|
318 |
+
"""Residual vector quantization implementation.
|
319 |
+
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
|
320 |
+
"""
|
321 |
+
def __init__(self, *, num_quantizers, **kwargs):
|
322 |
+
super().__init__()
|
323 |
+
self.layers = nn.ModuleList(
|
324 |
+
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
|
325 |
+
)
|
326 |
+
|
327 |
+
def forward(self, x, n_q: tp.Optional[int] = None):
|
328 |
+
quantized_out = 0.0
|
329 |
+
residual = x
|
330 |
+
|
331 |
+
all_losses = []
|
332 |
+
all_indices = []
|
333 |
+
|
334 |
+
n_q = n_q or len(self.layers)
|
335 |
+
|
336 |
+
for layerinx, layer in enumerate(self.layers[:n_q]):
|
337 |
+
print("Layer {} Used ratio {:.1f}".format(layerinx, (layer._codebook.cluster_size > 1.0).sum() / layer._codebook.cluster_size.shape[0] * 100.))
|
338 |
+
quantized, indices, loss = layer(residual)
|
339 |
+
residual = residual - quantized
|
340 |
+
quantized_out = quantized_out + quantized
|
341 |
+
|
342 |
+
all_indices.append(indices)
|
343 |
+
all_losses.append(loss)
|
344 |
+
|
345 |
+
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
|
346 |
+
return quantized_out, out_indices, out_losses
|
347 |
+
|
348 |
+
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
|
349 |
+
residual = x
|
350 |
+
all_indices = []
|
351 |
+
n_q = n_q or len(self.layers)
|
352 |
+
for layer in self.layers[:n_q]:
|
353 |
+
indices = layer.encode(residual)
|
354 |
+
quantized = layer.decode(indices)
|
355 |
+
residual = residual - quantized
|
356 |
+
all_indices.append(indices)
|
357 |
+
out_indices = torch.stack(all_indices)
|
358 |
+
return out_indices
|
359 |
+
|
360 |
+
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
|
361 |
+
quantized_out = torch.tensor(0.0, device=q_indices.device)
|
362 |
+
for i, indices in enumerate(q_indices):
|
363 |
+
layer = self.layers[i]
|
364 |
+
quantized = layer.decode(indices)
|
365 |
+
quantized_out = quantized_out + quantized
|
366 |
+
return quantized_out
|
codeclm/tokenizer/Flow1dVAE/model_1rvq.py
CHANGED
@@ -1,710 +1,710 @@
|
|
1 |
-
import yaml
|
2 |
-
import random
|
3 |
-
import inspect
|
4 |
-
import numpy as np
|
5 |
-
from tqdm import tqdm
|
6 |
-
import typing as tp
|
7 |
-
from abc import ABC
|
8 |
-
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
import torch.nn.functional as F
|
12 |
-
import torchaudio
|
13 |
-
|
14 |
-
from tools.torch_tools import wav_to_fbank
|
15 |
-
|
16 |
-
from diffusers.utils.torch_utils import randn_tensor
|
17 |
-
from transformers import HubertModel
|
18 |
-
from libs.rvq.descript_quantize3 import ResidualVectorQuantize
|
19 |
-
|
20 |
-
from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
|
21 |
-
from models_gpt.models.gpt2_config import GPT2Config
|
22 |
-
|
23 |
-
from torch.cuda.amp import autocast
|
24 |
-
|
25 |
-
|
26 |
-
from our_MERT_BESTRQ.test import load_model
|
27 |
-
|
28 |
-
class HubertModelWithFinalProj(HubertModel):
|
29 |
-
def __init__(self, config):
|
30 |
-
super().__init__(config)
|
31 |
-
|
32 |
-
# The final projection layer is only used for backward compatibility.
|
33 |
-
# Following https://github.com/auspicious3000/contentvec/issues/6
|
34 |
-
# Remove this layer is necessary to achieve the desired outcome.
|
35 |
-
print("hidden_size:",config.hidden_size)
|
36 |
-
print("classifier_proj_size:",config.classifier_proj_size)
|
37 |
-
self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
38 |
-
|
39 |
-
|
40 |
-
class SampleProcessor(torch.nn.Module):
|
41 |
-
def project_sample(self, x: torch.Tensor):
|
42 |
-
"""Project the original sample to the 'space' where the diffusion will happen."""
|
43 |
-
"""Project back from diffusion space to the actual sample space."""
|
44 |
-
return z
|
45 |
-
|
46 |
-
class Feature1DProcessor(SampleProcessor):
|
47 |
-
def __init__(self, dim: int = 100, power_std = 1., \
|
48 |
-
num_samples: int = 100_000, cal_num_frames: int = 600):
|
49 |
-
super().__init__()
|
50 |
-
|
51 |
-
self.num_samples = num_samples
|
52 |
-
self.dim = dim
|
53 |
-
self.power_std = power_std
|
54 |
-
self.cal_num_frames = cal_num_frames
|
55 |
-
self.register_buffer('counts', torch.zeros(1))
|
56 |
-
self.register_buffer('sum_x', torch.zeros(dim))
|
57 |
-
self.register_buffer('sum_x2', torch.zeros(dim))
|
58 |
-
self.register_buffer('sum_target_x2', torch.zeros(dim))
|
59 |
-
self.counts: torch.Tensor
|
60 |
-
self.sum_x: torch.Tensor
|
61 |
-
self.sum_x2: torch.Tensor
|
62 |
-
|
63 |
-
@property
|
64 |
-
def mean(self):
|
65 |
-
mean = self.sum_x / self.counts
|
66 |
-
if(self.counts < 10):
|
67 |
-
mean = torch.zeros_like(mean)
|
68 |
-
return mean
|
69 |
-
|
70 |
-
@property
|
71 |
-
def std(self):
|
72 |
-
std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
|
73 |
-
if(self.counts < 10):
|
74 |
-
std = torch.ones_like(std)
|
75 |
-
return std
|
76 |
-
|
77 |
-
@property
|
78 |
-
def target_std(self):
|
79 |
-
return 1
|
80 |
-
|
81 |
-
def project_sample(self, x: torch.Tensor):
|
82 |
-
assert x.dim() == 3
|
83 |
-
if self.counts.item() < self.num_samples:
|
84 |
-
self.counts += len(x)
|
85 |
-
self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
|
86 |
-
self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
|
87 |
-
rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
|
88 |
-
x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
|
89 |
-
return x
|
90 |
-
|
91 |
-
def return_sample(self, x: torch.Tensor):
|
92 |
-
assert x.dim() == 3
|
93 |
-
rescale = (self.std / self.target_std) ** self.power_std
|
94 |
-
# print(rescale, self.mean)
|
95 |
-
x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
|
96 |
-
return x
|
97 |
-
|
98 |
-
def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
|
99 |
-
if(prior_text_encoder_hidden_states.shape[1]<len_size):
|
100 |
-
prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
|
101 |
-
torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
|
102 |
-
prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
|
103 |
-
dtype=prior_text_encoder_hidden_states.dtype)],1)
|
104 |
-
prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
|
105 |
-
else:
|
106 |
-
prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
|
107 |
-
prior_text_mask = prior_text_mask[:,0:len_size]
|
108 |
-
prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
|
109 |
-
return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
|
110 |
-
|
111 |
-
class BASECFM(torch.nn.Module, ABC):
|
112 |
-
def __init__(
|
113 |
-
self,
|
114 |
-
estimator,
|
115 |
-
mlp,
|
116 |
-
ssl_layer
|
117 |
-
):
|
118 |
-
super().__init__()
|
119 |
-
self.sigma_min = 1e-4
|
120 |
-
|
121 |
-
self.estimator = estimator
|
122 |
-
self.mlp = mlp
|
123 |
-
self.ssl_layer = ssl_layer
|
124 |
-
|
125 |
-
@torch.inference_mode()
|
126 |
-
def forward(self, mu, n_timesteps, temperature=1.0):
|
127 |
-
"""Forward diffusion
|
128 |
-
|
129 |
-
Args:
|
130 |
-
mu (torch.Tensor): output of encoder
|
131 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
132 |
-
n_timesteps (int): number of diffusion steps
|
133 |
-
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
134 |
-
|
135 |
-
Returns:
|
136 |
-
sample: generated mel-spectrogram
|
137 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
138 |
-
"""
|
139 |
-
z = torch.randn_like(mu) * temperature
|
140 |
-
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
141 |
-
return self.solve_euler(z, t_span=t_span)
|
142 |
-
|
143 |
-
def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
|
144 |
-
"""
|
145 |
-
Fixed euler solver for ODEs.
|
146 |
-
Args:
|
147 |
-
x (torch.Tensor): random noise
|
148 |
-
t_span (torch.Tensor): n_timesteps interpolated
|
149 |
-
shape: (n_timesteps + 1,)
|
150 |
-
mu (torch.Tensor): output of encoder
|
151 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
152 |
-
"""
|
153 |
-
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
154 |
-
noise = x.clone()
|
155 |
-
|
156 |
-
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
157 |
-
# Or in future might add like a return_all_steps flag
|
158 |
-
sol = []
|
159 |
-
|
160 |
-
for step in tqdm(range(1, len(t_span))):
|
161 |
-
# print("incontext_x.shape:",incontext_x.shape)
|
162 |
-
# print("noise.shape:",noise.shape)
|
163 |
-
# print("t.shape:",t.shape)
|
164 |
-
x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
|
165 |
-
if(guidance_scale > 1.0):
|
166 |
-
|
167 |
-
model_input = torch.cat([ \
|
168 |
-
torch.cat([latent_mask_input, latent_mask_input], 0), \
|
169 |
-
torch.cat([incontext_x, incontext_x], 0), \
|
170 |
-
torch.cat([torch.zeros_like(mu), mu], 0), \
|
171 |
-
torch.cat([x, x], 0), \
|
172 |
-
], 2)
|
173 |
-
timestep=t.unsqueeze(-1).repeat(2)
|
174 |
-
|
175 |
-
dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
|
176 |
-
dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0)
|
177 |
-
dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond)
|
178 |
-
else:
|
179 |
-
model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
|
180 |
-
timestep=t.unsqueeze(-1)
|
181 |
-
dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
|
182 |
-
|
183 |
-
dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
|
184 |
-
# print("dphi_dt.shape:",dphi_dt.shape)
|
185 |
-
# print("x.shape:",x.shape)
|
186 |
-
|
187 |
-
x = x + dt * dphi_dt
|
188 |
-
t = t + dt
|
189 |
-
sol.append(x)
|
190 |
-
if step < len(t_span) - 1:
|
191 |
-
dt = t_span[step + 1] - t
|
192 |
-
|
193 |
-
return sol[-1]
|
194 |
-
|
195 |
-
def projection_loss(self,hidden_proj, bestrq_emb):
|
196 |
-
bsz = hidden_proj.shape[0]
|
197 |
-
|
198 |
-
hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
|
199 |
-
bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
|
200 |
-
|
201 |
-
proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
|
202 |
-
proj_loss = 1+proj_loss.mean()
|
203 |
-
|
204 |
-
return proj_loss
|
205 |
-
|
206 |
-
def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
|
207 |
-
"""Computes diffusion loss
|
208 |
-
|
209 |
-
Args:
|
210 |
-
x1 (torch.Tensor): Target
|
211 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
212 |
-
mu (torch.Tensor): output of encoder
|
213 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
214 |
-
|
215 |
-
Returns:
|
216 |
-
loss: conditional flow matching loss
|
217 |
-
y: conditional flow
|
218 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
219 |
-
"""
|
220 |
-
b = mu[0].shape[0]
|
221 |
-
len_x = x1.shape[2]
|
222 |
-
# random timestep
|
223 |
-
if(validation_mode):
|
224 |
-
t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
|
225 |
-
else:
|
226 |
-
t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
|
227 |
-
# sample noise p(x_0)
|
228 |
-
z = torch.randn_like(x1)
|
229 |
-
|
230 |
-
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
231 |
-
u = x1 - (1 - self.sigma_min) * z
|
232 |
-
# print("y.shape:",y.shape)
|
233 |
-
#self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state
|
234 |
-
model_input = torch.cat([*mu,y], 2)
|
235 |
-
t=t.squeeze(-1).squeeze(-1)
|
236 |
-
# print("model_input.shape:",model_input.shape)
|
237 |
-
# print("attention_mask.shape:",attention_mask.shape)
|
238 |
-
out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
|
239 |
-
hidden_layer = out.hidden_states[self.ssl_layer]
|
240 |
-
hidden_proj = self.mlp(hidden_layer)
|
241 |
-
# print("hidden_proj.shape:",hidden_proj.shape)
|
242 |
-
# print("mert_emb.shape:",mert_emb.shape)
|
243 |
-
# exit()
|
244 |
-
|
245 |
-
|
246 |
-
out = out.last_hidden_state
|
247 |
-
|
248 |
-
out=out[:,:,-len_x:]
|
249 |
-
# out=self.proj_out(out)
|
250 |
-
|
251 |
-
weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
|
252 |
-
# print("out.shape",out.shape)
|
253 |
-
# print("u.shape",u.shape)
|
254 |
-
loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
|
255 |
-
# print("hidden_proj.shape:",hidden_proj.shape)
|
256 |
-
# print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
|
257 |
-
loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
|
258 |
-
loss = loss_re + loss_cos * 0.5
|
259 |
-
# print("loss_cos:",loss_cos,loss_cos.device)
|
260 |
-
print("loss:",loss,loss.device)
|
261 |
-
# exit()
|
262 |
-
return loss, loss_re, loss_cos
|
263 |
-
|
264 |
-
class PromptCondAudioDiffusion(nn.Module):
|
265 |
-
def __init__(
|
266 |
-
self,
|
267 |
-
num_channels,
|
268 |
-
unet_model_name=None,
|
269 |
-
unet_model_config_path=None,
|
270 |
-
snr_gamma=None,
|
271 |
-
hubert_layer=None,
|
272 |
-
ssl_layer=None,
|
273 |
-
uncondition=True,
|
274 |
-
out_paint=False,
|
275 |
-
):
|
276 |
-
super().__init__()
|
277 |
-
|
278 |
-
assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
|
279 |
-
|
280 |
-
self.unet_model_name = unet_model_name
|
281 |
-
self.unet_model_config_path = unet_model_config_path
|
282 |
-
self.snr_gamma = snr_gamma
|
283 |
-
self.uncondition = uncondition
|
284 |
-
self.num_channels = num_channels
|
285 |
-
self.hubert_layer = hubert_layer
|
286 |
-
self.ssl_layer = ssl_layer
|
287 |
-
|
288 |
-
# https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
|
289 |
-
self.normfeat = Feature1DProcessor(dim=64)
|
290 |
-
|
291 |
-
self.sample_rate = 48000
|
292 |
-
self.num_samples_perseg = self.sample_rate * 20 // 1000
|
293 |
-
self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
|
294 |
-
self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
|
295 |
-
# self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
|
296 |
-
# self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
|
297 |
-
self.bestrq = load_model(
|
298 |
-
model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq',
|
299 |
-
checkpoint_dir='ckpt/encode-s12k.pt',
|
300 |
-
)
|
301 |
-
self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
|
302 |
-
self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
|
303 |
-
for v in self.bestrq.parameters():v.requires_grad = False
|
304 |
-
self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
|
305 |
-
for v in self.rvq_bestrq_emb.parameters():v.requires_grad = False
|
306 |
-
self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
|
307 |
-
for v in self.hubert.parameters():v.requires_grad = False
|
308 |
-
self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
|
309 |
-
# self.xvecmodel = XVECModel()
|
310 |
-
config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
|
311 |
-
unet = GPT2Model(config)
|
312 |
-
mlp = nn.Sequential(
|
313 |
-
nn.Linear(1200, 1024),
|
314 |
-
nn.SiLU(),
|
315 |
-
nn.Linear(1024, 1024),
|
316 |
-
nn.SiLU(),
|
317 |
-
nn.Linear(1024, 768)
|
318 |
-
)
|
319 |
-
self.set_from = "random"
|
320 |
-
self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
|
321 |
-
self.mask_emb = torch.nn.Embedding(3, 48)
|
322 |
-
print("Transformer initialized from pretrain.")
|
323 |
-
torch.cuda.empty_cache()
|
324 |
-
# self.unet.set_attn_processor(AttnProcessor2_0())
|
325 |
-
# self.unet.set_use_memory_efficient_attention_xformers(True)
|
326 |
-
|
327 |
-
# self.start_embedding = nn.Parameter(torch.randn(1,1024))
|
328 |
-
# self.end_embedding = nn.Parameter(torch.randn(1,1024))
|
329 |
-
|
330 |
-
def compute_snr(self, timesteps):
|
331 |
-
"""
|
332 |
-
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
333 |
-
"""
|
334 |
-
alphas_cumprod = self.noise_scheduler.alphas_cumprod
|
335 |
-
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
336 |
-
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
337 |
-
|
338 |
-
# Expand the tensors.
|
339 |
-
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
340 |
-
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
341 |
-
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
342 |
-
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
343 |
-
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
344 |
-
|
345 |
-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
346 |
-
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
347 |
-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
348 |
-
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
349 |
-
|
350 |
-
# Compute SNR.
|
351 |
-
snr = (alpha / sigma) ** 2
|
352 |
-
return snr
|
353 |
-
|
354 |
-
def preprocess_audio(self, input_audios, threshold=0.9):
|
355 |
-
assert len(input_audios.shape) == 2, input_audios.shape
|
356 |
-
norm_value = torch.ones_like(input_audios[:,0])
|
357 |
-
max_volume = input_audios.abs().max(dim=-1)[0]
|
358 |
-
norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
|
359 |
-
return input_audios/norm_value.unsqueeze(-1)
|
360 |
-
|
361 |
-
def extract_wav2vec_embeds(self, input_audios,output_len):
|
362 |
-
wav2vec_stride = 2
|
363 |
-
|
364 |
-
wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
|
365 |
-
# print(wav2vec_embeds)
|
366 |
-
# print("audio.shape:",input_audios.shape)
|
367 |
-
wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer]
|
368 |
-
# print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape)
|
369 |
-
wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
|
370 |
-
return wav2vec_embeds_last
|
371 |
-
|
372 |
-
def extract_mert_embeds(self, input_audios):
|
373 |
-
prompt_stride = 3
|
374 |
-
inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
|
375 |
-
input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
|
376 |
-
prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
|
377 |
-
mert_emb= prompt_embeds[-1]
|
378 |
-
mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1)
|
379 |
-
|
380 |
-
return mert_emb
|
381 |
-
|
382 |
-
def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer):
|
383 |
-
self.bestrq.eval()
|
384 |
-
# print("audio shape:",input_audio_0.shape)
|
385 |
-
input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
|
386 |
-
# print("input_wav_mean.shape:",input_wav_mean.shape)
|
387 |
-
# input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device)
|
388 |
-
input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
|
389 |
-
layer_results = input_wav_mean['layer_results']
|
390 |
-
# print("layer_results.shape:",layer_results[layer].shape)
|
391 |
-
bestrq_emb = layer_results[layer]
|
392 |
-
bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
|
393 |
-
#[b,t,1024] t=t/960
|
394 |
-
#35.84s->batch,896,1024
|
395 |
-
return bestrq_emb
|
396 |
-
|
397 |
-
|
398 |
-
def extract_spk_embeds(self, input_audios):
|
399 |
-
spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
|
400 |
-
spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
|
401 |
-
return spk_embeds
|
402 |
-
|
403 |
-
def extract_lyric_feats(self, lyric):
|
404 |
-
with torch.no_grad():
|
405 |
-
try:
|
406 |
-
text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
|
407 |
-
except:
|
408 |
-
text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
|
409 |
-
text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
|
410 |
-
text_mask = text_mask.to(self.device)
|
411 |
-
text_encoder_hidden_states, text_mask, text_prompt_embeds = \
|
412 |
-
pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
|
413 |
-
text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
|
414 |
-
return text_encoder_hidden_states, text_mask
|
415 |
-
|
416 |
-
def extract_energy_bar(self, input_audios):
|
417 |
-
if(input_audios.shape[-1] % self.num_samples_perseg > 0):
|
418 |
-
energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
|
419 |
-
else:
|
420 |
-
energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
|
421 |
-
energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
|
422 |
-
energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
|
423 |
-
energy_embedding = self.energy_embedding(energy_bar)
|
424 |
-
energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
|
425 |
-
return energy_embedding
|
426 |
-
|
427 |
-
def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \
|
428 |
-
additional_feats = ['spk', 'lyric'], \
|
429 |
-
train_rvq=True, train_ssl=False,layer=5):
|
430 |
-
if not hasattr(self,"device"):
|
431 |
-
self.device = input_audios.device
|
432 |
-
if not hasattr(self,"dtype"):
|
433 |
-
self.dtype = input_audios.dtype
|
434 |
-
device = self.device
|
435 |
-
input_audio_0 = input_audios[:,0,:]
|
436 |
-
input_audio_1 = input_audios[:,1,:]
|
437 |
-
input_audio_0 = self.preprocess_audio(input_audio_0)
|
438 |
-
input_audio_1 = self.preprocess_audio(input_audio_1)
|
439 |
-
input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0
|
440 |
-
# energy_embedding = self.extract_energy_bar(input_audios)
|
441 |
-
# print("energy_embedding.shape:",energy_embedding.shape)
|
442 |
-
# with autocast(enabled=False):
|
443 |
-
if(train_ssl):
|
444 |
-
self.wav2vec.train()
|
445 |
-
wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
|
446 |
-
self.clap_embd_extractor.train()
|
447 |
-
prompt_embeds = self.extract_mert_embeds(input_audios)
|
448 |
-
if('spk' in additional_feats):
|
449 |
-
self.xvecmodel.train()
|
450 |
-
spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
|
451 |
-
else:
|
452 |
-
with torch.no_grad():
|
453 |
-
with autocast(enabled=False):
|
454 |
-
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
455 |
-
# mert_emb = self.extract_mert_embeds(input_audios_mert)
|
456 |
-
|
457 |
-
wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2])
|
458 |
-
|
459 |
-
bestrq_emb = bestrq_emb.detach()
|
460 |
-
if('lyric' in additional_feats):
|
461 |
-
text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
|
462 |
-
else:
|
463 |
-
text_encoder_hidden_states, text_mask = None, None
|
464 |
-
|
465 |
-
# prompt_embeds_13 = torch.cat([mert_emb_13, energy_embedding], 1)
|
466 |
-
# print("prompt_embes.shape:",prompt_embeds.shape)
|
467 |
-
#prompt_embes.shape: torch.Size([3, 1088, 896])
|
468 |
-
# print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
|
469 |
-
#wav2vec_embeds.shape:torch.Size([3, 1024, 896])
|
470 |
-
if(train_rvq):
|
471 |
-
quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
472 |
-
else:
|
473 |
-
bestrq_emb = bestrq_emb.float()
|
474 |
-
self.rvq_bestrq_emb.eval()
|
475 |
-
# with autocast(enabled=False):
|
476 |
-
quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
477 |
-
commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
|
478 |
-
codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
|
479 |
-
quantized_bestrq_emb = quantized_bestrq_emb.detach()
|
480 |
-
|
481 |
-
commitment_loss = commitment_loss_bestrq_emb
|
482 |
-
codebook_loss = codebook_loss_bestrq_emb
|
483 |
-
|
484 |
-
|
485 |
-
alpha=1
|
486 |
-
quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
|
487 |
-
|
488 |
-
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
489 |
-
# print("latent_masks.shape:",latent_masks.shape)
|
490 |
-
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
scenario = np.random.choice(['start_seg', 'other_seg'])
|
495 |
-
if(scenario == 'other_seg'):
|
496 |
-
for binx in range(input_audios.shape[0]):
|
497 |
-
# latent_masks[binx,0:64] = 1
|
498 |
-
latent_masks[binx,0:random.randint(64,128)] = 1
|
499 |
-
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
500 |
-
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
501 |
-
# print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape)
|
502 |
-
# print("latent_masks.shape:",latent_masks.shape)
|
503 |
-
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
504 |
-
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
505 |
-
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
if self.uncondition:
|
510 |
-
mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
|
511 |
-
if len(mask_indices) > 0:
|
512 |
-
quantized_bestrq_emb[mask_indices] = 0
|
513 |
-
# print("latents.shape:",latents.shape)
|
514 |
-
latents = latents.permute(0,2,1).contiguous()
|
515 |
-
latents = self.normfeat.project_sample(latents)
|
516 |
-
latents = latents.permute(0,2,1).contiguous()
|
517 |
-
incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
518 |
-
attention_mask=(latent_masks > 0.5)
|
519 |
-
B, L = attention_mask.size()
|
520 |
-
attention_mask = attention_mask.view(B, 1, L)
|
521 |
-
attention_mask = attention_mask * attention_mask.transpose(-1, -2)
|
522 |
-
attention_mask = attention_mask.unsqueeze(1)
|
523 |
-
# print("incontext_latents.shape:",incontext_latents.shape)
|
524 |
-
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
525 |
-
latent_mask_input = self.mask_emb(latent_masks)
|
526 |
-
#64+48+64+1024
|
527 |
-
loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
|
528 |
-
return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
|
529 |
-
|
530 |
-
def init_device_dtype(self, device, dtype):
|
531 |
-
self.device = device
|
532 |
-
self.dtype = dtype
|
533 |
-
|
534 |
-
@torch.no_grad()
|
535 |
-
def fetch_codes(self, input_audios, additional_feats,layer):
|
536 |
-
input_audio_0 = input_audios[[0],:]
|
537 |
-
input_audio_1 = input_audios[[1],:]
|
538 |
-
input_audio_0 = self.preprocess_audio(input_audio_0)
|
539 |
-
input_audio_1 = self.preprocess_audio(input_audio_1)
|
540 |
-
|
541 |
-
self.bestrq.eval()
|
542 |
-
|
543 |
-
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
544 |
-
# bestrq_middle = bestrq_middle.detach()
|
545 |
-
# bestrq_last = bestrq_last.detach()
|
546 |
-
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
547 |
-
bestrq_emb = bestrq_emb.detach()
|
548 |
-
|
549 |
-
# self.rvq_bestrq_middle.eval()
|
550 |
-
# quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
|
551 |
-
# self.rvq_bestrq_last.eval()
|
552 |
-
# quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
|
553 |
-
|
554 |
-
self.rvq_bestrq_emb.eval()
|
555 |
-
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
556 |
-
|
557 |
-
|
558 |
-
if('spk' in additional_feats):
|
559 |
-
self.xvecmodel.eval()
|
560 |
-
spk_embeds = self.extract_spk_embeds(input_audios)
|
561 |
-
else:
|
562 |
-
spk_embeds = None
|
563 |
-
|
564 |
-
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
565 |
-
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
566 |
-
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
567 |
-
return [codes_bestrq_emb], [bestrq_emb], spk_embeds
|
568 |
-
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
569 |
-
|
570 |
-
|
571 |
-
@torch.no_grad()
|
572 |
-
def fetch_codes_batch(self, input_audios, additional_feats,layer):
|
573 |
-
input_audio_0 = input_audios[:,0,:]
|
574 |
-
input_audio_1 = input_audios[:,1,:]
|
575 |
-
input_audio_0 = self.preprocess_audio(input_audio_0)
|
576 |
-
input_audio_1 = self.preprocess_audio(input_audio_1)
|
577 |
-
|
578 |
-
self.bestrq.eval()
|
579 |
-
|
580 |
-
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
581 |
-
# bestrq_middle = bestrq_middle.detach()
|
582 |
-
# bestrq_last = bestrq_last.detach()
|
583 |
-
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
584 |
-
bestrq_emb = bestrq_emb.detach()
|
585 |
-
|
586 |
-
# self.rvq_bestrq_middle.eval()
|
587 |
-
# quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
|
588 |
-
# self.rvq_bestrq_last.eval()
|
589 |
-
# quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
|
590 |
-
|
591 |
-
self.rvq_bestrq_emb.eval()
|
592 |
-
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
593 |
-
|
594 |
-
|
595 |
-
if('spk' in additional_feats):
|
596 |
-
self.xvecmodel.eval()
|
597 |
-
spk_embeds = self.extract_spk_embeds(input_audios)
|
598 |
-
else:
|
599 |
-
spk_embeds = None
|
600 |
-
|
601 |
-
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
602 |
-
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
603 |
-
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
604 |
-
return [codes_bestrq_emb], [bestrq_emb], spk_embeds
|
605 |
-
|
606 |
-
@torch.no_grad()
|
607 |
-
def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127,
|
608 |
-
guidance_scale=2, num_steps=20,
|
609 |
-
disable_progress=True, scenario='start_seg'):
|
610 |
-
classifier_free_guidance = guidance_scale > 1.0
|
611 |
-
device = self.device
|
612 |
-
dtype = self.dtype
|
613 |
-
# codes_bestrq_middle, codes_bestrq_last = codes
|
614 |
-
codes_bestrq_emb = codes[0]
|
615 |
-
|
616 |
-
|
617 |
-
batch_size = codes_bestrq_emb.shape[0]
|
618 |
-
|
619 |
-
|
620 |
-
quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
|
621 |
-
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
622 |
-
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
623 |
-
print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
624 |
-
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
625 |
-
|
626 |
-
|
627 |
-
|
628 |
-
|
629 |
-
if('spk' in additional_feats):
|
630 |
-
spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
|
631 |
-
|
632 |
-
num_frames = quantized_bestrq_emb.shape[1]
|
633 |
-
|
634 |
-
num_channels_latents = self.num_channels
|
635 |
-
shape = (batch_size, num_frames, 64)
|
636 |
-
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
637 |
-
|
638 |
-
|
639 |
-
|
640 |
-
latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
|
641 |
-
latent_masks[:,0:latent_length] = 2
|
642 |
-
if(scenario=='other_seg'):
|
643 |
-
latent_masks[:,0:incontext_length] = 1
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
648 |
-
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
649 |
-
true_latents = true_latents.permute(0,2,1).contiguous()
|
650 |
-
true_latents = self.normfeat.project_sample(true_latents)
|
651 |
-
true_latents = true_latents.permute(0,2,1).contiguous()
|
652 |
-
incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
653 |
-
incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
|
654 |
-
|
655 |
-
|
656 |
-
attention_mask=(latent_masks > 0.5)
|
657 |
-
B, L = attention_mask.size()
|
658 |
-
attention_mask = attention_mask.view(B, 1, L)
|
659 |
-
attention_mask = attention_mask * attention_mask.transpose(-1, -2)
|
660 |
-
attention_mask = attention_mask.unsqueeze(1)
|
661 |
-
latent_mask_input = self.mask_emb(latent_masks)
|
662 |
-
|
663 |
-
if('spk' in additional_feats):
|
664 |
-
# additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
|
665 |
-
additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1)
|
666 |
-
else:
|
667 |
-
# additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
|
668 |
-
additional_model_input = torch.cat([quantized_bestrq_emb],1)
|
669 |
-
|
670 |
-
temperature = 1.0
|
671 |
-
t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
|
672 |
-
latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
|
673 |
-
|
674 |
-
latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
|
675 |
-
latents = latents.permute(0,2,1).contiguous()
|
676 |
-
latents = self.normfeat.return_sample(latents)
|
677 |
-
# latents = latents.permute(0,2,1).contiguous()
|
678 |
-
return latents
|
679 |
-
|
680 |
-
@torch.no_grad()
|
681 |
-
def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
|
682 |
-
disable_progress=True,layer=5,scenario='start_seg'):
|
683 |
-
codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
|
684 |
-
|
685 |
-
latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
|
686 |
-
guidance_scale=guidance_scale, num_steps=num_steps, \
|
687 |
-
disable_progress=disable_progress,scenario=scenario)
|
688 |
-
return latents
|
689 |
-
|
690 |
-
@torch.no_grad()
|
691 |
-
def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
|
692 |
-
disable_progress=True,layer=5,scenario='start_seg'):
|
693 |
-
codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
|
694 |
-
import time
|
695 |
-
start = time.time()
|
696 |
-
latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
|
697 |
-
guidance_scale=guidance_scale, num_steps=num_steps, \
|
698 |
-
disable_progress=disable_progress,scenario=scenario)
|
699 |
-
return latents,time.time()-start
|
700 |
-
|
701 |
-
def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
|
702 |
-
divisor = 4
|
703 |
-
shape = (batch_size, num_channels_latents, num_frames, 32)
|
704 |
-
if(num_frames%divisor>0):
|
705 |
-
num_frames = round(num_frames/float(divisor))*divisor
|
706 |
-
shape = (batch_size, num_channels_latents, num_frames, 32)
|
707 |
-
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
708 |
-
return latents
|
709 |
-
|
710 |
-
|
|
|
1 |
+
import yaml
|
2 |
+
import random
|
3 |
+
import inspect
|
4 |
+
import numpy as np
|
5 |
+
from tqdm import tqdm
|
6 |
+
import typing as tp
|
7 |
+
from abc import ABC
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import torchaudio
|
13 |
+
|
14 |
+
from tools.torch_tools import wav_to_fbank
|
15 |
+
|
16 |
+
from diffusers.utils.torch_utils import randn_tensor
|
17 |
+
from transformers import HubertModel
|
18 |
+
from libs.rvq.descript_quantize3 import ResidualVectorQuantize
|
19 |
+
|
20 |
+
from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
|
21 |
+
from models_gpt.models.gpt2_config import GPT2Config
|
22 |
+
|
23 |
+
from torch.cuda.amp import autocast
|
24 |
+
|
25 |
+
|
26 |
+
from our_MERT_BESTRQ.test import load_model
|
27 |
+
|
28 |
+
class HubertModelWithFinalProj(HubertModel):
|
29 |
+
def __init__(self, config):
|
30 |
+
super().__init__(config)
|
31 |
+
|
32 |
+
# The final projection layer is only used for backward compatibility.
|
33 |
+
# Following https://github.com/auspicious3000/contentvec/issues/6
|
34 |
+
# Remove this layer is necessary to achieve the desired outcome.
|
35 |
+
print("hidden_size:",config.hidden_size)
|
36 |
+
print("classifier_proj_size:",config.classifier_proj_size)
|
37 |
+
self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
38 |
+
|
39 |
+
|
40 |
+
class SampleProcessor(torch.nn.Module):
|
41 |
+
def project_sample(self, x: torch.Tensor):
|
42 |
+
"""Project the original sample to the 'space' where the diffusion will happen."""
|
43 |
+
"""Project back from diffusion space to the actual sample space."""
|
44 |
+
return z
|
45 |
+
|
46 |
+
class Feature1DProcessor(SampleProcessor):
|
47 |
+
def __init__(self, dim: int = 100, power_std = 1., \
|
48 |
+
num_samples: int = 100_000, cal_num_frames: int = 600):
|
49 |
+
super().__init__()
|
50 |
+
|
51 |
+
self.num_samples = num_samples
|
52 |
+
self.dim = dim
|
53 |
+
self.power_std = power_std
|
54 |
+
self.cal_num_frames = cal_num_frames
|
55 |
+
self.register_buffer('counts', torch.zeros(1))
|
56 |
+
self.register_buffer('sum_x', torch.zeros(dim))
|
57 |
+
self.register_buffer('sum_x2', torch.zeros(dim))
|
58 |
+
self.register_buffer('sum_target_x2', torch.zeros(dim))
|
59 |
+
self.counts: torch.Tensor
|
60 |
+
self.sum_x: torch.Tensor
|
61 |
+
self.sum_x2: torch.Tensor
|
62 |
+
|
63 |
+
@property
|
64 |
+
def mean(self):
|
65 |
+
mean = self.sum_x / self.counts
|
66 |
+
if(self.counts < 10):
|
67 |
+
mean = torch.zeros_like(mean)
|
68 |
+
return mean
|
69 |
+
|
70 |
+
@property
|
71 |
+
def std(self):
|
72 |
+
std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
|
73 |
+
if(self.counts < 10):
|
74 |
+
std = torch.ones_like(std)
|
75 |
+
return std
|
76 |
+
|
77 |
+
@property
|
78 |
+
def target_std(self):
|
79 |
+
return 1
|
80 |
+
|
81 |
+
def project_sample(self, x: torch.Tensor):
|
82 |
+
assert x.dim() == 3
|
83 |
+
if self.counts.item() < self.num_samples:
|
84 |
+
self.counts += len(x)
|
85 |
+
self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
|
86 |
+
self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
|
87 |
+
rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
|
88 |
+
x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
|
89 |
+
return x
|
90 |
+
|
91 |
+
def return_sample(self, x: torch.Tensor):
|
92 |
+
assert x.dim() == 3
|
93 |
+
rescale = (self.std / self.target_std) ** self.power_std
|
94 |
+
# print(rescale, self.mean)
|
95 |
+
x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
|
96 |
+
return x
|
97 |
+
|
98 |
+
def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
|
99 |
+
if(prior_text_encoder_hidden_states.shape[1]<len_size):
|
100 |
+
prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
|
101 |
+
torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
|
102 |
+
prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
|
103 |
+
dtype=prior_text_encoder_hidden_states.dtype)],1)
|
104 |
+
prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
|
105 |
+
else:
|
106 |
+
prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
|
107 |
+
prior_text_mask = prior_text_mask[:,0:len_size]
|
108 |
+
prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
|
109 |
+
return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
|
110 |
+
|
111 |
+
class BASECFM(torch.nn.Module, ABC):
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
estimator,
|
115 |
+
mlp,
|
116 |
+
ssl_layer
|
117 |
+
):
|
118 |
+
super().__init__()
|
119 |
+
self.sigma_min = 1e-4
|
120 |
+
|
121 |
+
self.estimator = estimator
|
122 |
+
self.mlp = mlp
|
123 |
+
self.ssl_layer = ssl_layer
|
124 |
+
|
125 |
+
@torch.inference_mode()
|
126 |
+
def forward(self, mu, n_timesteps, temperature=1.0):
|
127 |
+
"""Forward diffusion
|
128 |
+
|
129 |
+
Args:
|
130 |
+
mu (torch.Tensor): output of encoder
|
131 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
132 |
+
n_timesteps (int): number of diffusion steps
|
133 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
sample: generated mel-spectrogram
|
137 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
138 |
+
"""
|
139 |
+
z = torch.randn_like(mu) * temperature
|
140 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
141 |
+
return self.solve_euler(z, t_span=t_span)
|
142 |
+
|
143 |
+
def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
|
144 |
+
"""
|
145 |
+
Fixed euler solver for ODEs.
|
146 |
+
Args:
|
147 |
+
x (torch.Tensor): random noise
|
148 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
149 |
+
shape: (n_timesteps + 1,)
|
150 |
+
mu (torch.Tensor): output of encoder
|
151 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
152 |
+
"""
|
153 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
154 |
+
noise = x.clone()
|
155 |
+
|
156 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
157 |
+
# Or in future might add like a return_all_steps flag
|
158 |
+
sol = []
|
159 |
+
|
160 |
+
for step in tqdm(range(1, len(t_span))):
|
161 |
+
# print("incontext_x.shape:",incontext_x.shape)
|
162 |
+
# print("noise.shape:",noise.shape)
|
163 |
+
# print("t.shape:",t.shape)
|
164 |
+
x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
|
165 |
+
if(guidance_scale > 1.0):
|
166 |
+
|
167 |
+
model_input = torch.cat([ \
|
168 |
+
torch.cat([latent_mask_input, latent_mask_input], 0), \
|
169 |
+
torch.cat([incontext_x, incontext_x], 0), \
|
170 |
+
torch.cat([torch.zeros_like(mu), mu], 0), \
|
171 |
+
torch.cat([x, x], 0), \
|
172 |
+
], 2)
|
173 |
+
timestep=t.unsqueeze(-1).repeat(2)
|
174 |
+
|
175 |
+
dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
|
176 |
+
dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0)
|
177 |
+
dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond)
|
178 |
+
else:
|
179 |
+
model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
|
180 |
+
timestep=t.unsqueeze(-1)
|
181 |
+
dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
|
182 |
+
|
183 |
+
dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
|
184 |
+
# print("dphi_dt.shape:",dphi_dt.shape)
|
185 |
+
# print("x.shape:",x.shape)
|
186 |
+
|
187 |
+
x = x + dt * dphi_dt
|
188 |
+
t = t + dt
|
189 |
+
sol.append(x)
|
190 |
+
if step < len(t_span) - 1:
|
191 |
+
dt = t_span[step + 1] - t
|
192 |
+
|
193 |
+
return sol[-1]
|
194 |
+
|
195 |
+
def projection_loss(self,hidden_proj, bestrq_emb):
|
196 |
+
bsz = hidden_proj.shape[0]
|
197 |
+
|
198 |
+
hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
|
199 |
+
bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
|
200 |
+
|
201 |
+
proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
|
202 |
+
proj_loss = 1+proj_loss.mean()
|
203 |
+
|
204 |
+
return proj_loss
|
205 |
+
|
206 |
+
def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
|
207 |
+
"""Computes diffusion loss
|
208 |
+
|
209 |
+
Args:
|
210 |
+
x1 (torch.Tensor): Target
|
211 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
212 |
+
mu (torch.Tensor): output of encoder
|
213 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
loss: conditional flow matching loss
|
217 |
+
y: conditional flow
|
218 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
219 |
+
"""
|
220 |
+
b = mu[0].shape[0]
|
221 |
+
len_x = x1.shape[2]
|
222 |
+
# random timestep
|
223 |
+
if(validation_mode):
|
224 |
+
t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
|
225 |
+
else:
|
226 |
+
t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
|
227 |
+
# sample noise p(x_0)
|
228 |
+
z = torch.randn_like(x1)
|
229 |
+
|
230 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
231 |
+
u = x1 - (1 - self.sigma_min) * z
|
232 |
+
# print("y.shape:",y.shape)
|
233 |
+
#self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state
|
234 |
+
model_input = torch.cat([*mu,y], 2)
|
235 |
+
t=t.squeeze(-1).squeeze(-1)
|
236 |
+
# print("model_input.shape:",model_input.shape)
|
237 |
+
# print("attention_mask.shape:",attention_mask.shape)
|
238 |
+
out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
|
239 |
+
hidden_layer = out.hidden_states[self.ssl_layer]
|
240 |
+
hidden_proj = self.mlp(hidden_layer)
|
241 |
+
# print("hidden_proj.shape:",hidden_proj.shape)
|
242 |
+
# print("mert_emb.shape:",mert_emb.shape)
|
243 |
+
# exit()
|
244 |
+
|
245 |
+
|
246 |
+
out = out.last_hidden_state
|
247 |
+
|
248 |
+
out=out[:,:,-len_x:]
|
249 |
+
# out=self.proj_out(out)
|
250 |
+
|
251 |
+
weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
|
252 |
+
# print("out.shape",out.shape)
|
253 |
+
# print("u.shape",u.shape)
|
254 |
+
loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
|
255 |
+
# print("hidden_proj.shape:",hidden_proj.shape)
|
256 |
+
# print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
|
257 |
+
loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
|
258 |
+
loss = loss_re + loss_cos * 0.5
|
259 |
+
# print("loss_cos:",loss_cos,loss_cos.device)
|
260 |
+
print("loss:",loss,loss.device)
|
261 |
+
# exit()
|
262 |
+
return loss, loss_re, loss_cos
|
263 |
+
|
264 |
+
class PromptCondAudioDiffusion(nn.Module):
|
265 |
+
def __init__(
|
266 |
+
self,
|
267 |
+
num_channels,
|
268 |
+
unet_model_name=None,
|
269 |
+
unet_model_config_path=None,
|
270 |
+
snr_gamma=None,
|
271 |
+
hubert_layer=None,
|
272 |
+
ssl_layer=None,
|
273 |
+
uncondition=True,
|
274 |
+
out_paint=False,
|
275 |
+
):
|
276 |
+
super().__init__()
|
277 |
+
|
278 |
+
assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
|
279 |
+
|
280 |
+
self.unet_model_name = unet_model_name
|
281 |
+
self.unet_model_config_path = unet_model_config_path
|
282 |
+
self.snr_gamma = snr_gamma
|
283 |
+
self.uncondition = uncondition
|
284 |
+
self.num_channels = num_channels
|
285 |
+
self.hubert_layer = hubert_layer
|
286 |
+
self.ssl_layer = ssl_layer
|
287 |
+
|
288 |
+
# https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
|
289 |
+
self.normfeat = Feature1DProcessor(dim=64)
|
290 |
+
|
291 |
+
self.sample_rate = 48000
|
292 |
+
self.num_samples_perseg = self.sample_rate * 20 // 1000
|
293 |
+
self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
|
294 |
+
self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
|
295 |
+
# self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
|
296 |
+
# self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
|
297 |
+
self.bestrq = load_model(
|
298 |
+
model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq',
|
299 |
+
checkpoint_dir='ckpt/encode-s12k.pt',
|
300 |
+
)
|
301 |
+
self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
|
302 |
+
self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
|
303 |
+
for v in self.bestrq.parameters():v.requires_grad = False
|
304 |
+
self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
|
305 |
+
for v in self.rvq_bestrq_emb.parameters():v.requires_grad = False
|
306 |
+
self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
|
307 |
+
for v in self.hubert.parameters():v.requires_grad = False
|
308 |
+
self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
|
309 |
+
# self.xvecmodel = XVECModel()
|
310 |
+
config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
|
311 |
+
unet = GPT2Model(config)
|
312 |
+
mlp = nn.Sequential(
|
313 |
+
nn.Linear(1200, 1024),
|
314 |
+
nn.SiLU(),
|
315 |
+
nn.Linear(1024, 1024),
|
316 |
+
nn.SiLU(),
|
317 |
+
nn.Linear(1024, 768)
|
318 |
+
)
|
319 |
+
self.set_from = "random"
|
320 |
+
self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
|
321 |
+
self.mask_emb = torch.nn.Embedding(3, 48)
|
322 |
+
print("Transformer initialized from pretrain.")
|
323 |
+
torch.cuda.empty_cache()
|
324 |
+
# self.unet.set_attn_processor(AttnProcessor2_0())
|
325 |
+
# self.unet.set_use_memory_efficient_attention_xformers(True)
|
326 |
+
|
327 |
+
# self.start_embedding = nn.Parameter(torch.randn(1,1024))
|
328 |
+
# self.end_embedding = nn.Parameter(torch.randn(1,1024))
|
329 |
+
|
330 |
+
def compute_snr(self, timesteps):
|
331 |
+
"""
|
332 |
+
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
333 |
+
"""
|
334 |
+
alphas_cumprod = self.noise_scheduler.alphas_cumprod
|
335 |
+
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
336 |
+
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
337 |
+
|
338 |
+
# Expand the tensors.
|
339 |
+
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
340 |
+
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
341 |
+
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
342 |
+
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
343 |
+
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
344 |
+
|
345 |
+
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
346 |
+
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
347 |
+
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
348 |
+
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
349 |
+
|
350 |
+
# Compute SNR.
|
351 |
+
snr = (alpha / sigma) ** 2
|
352 |
+
return snr
|
353 |
+
|
354 |
+
def preprocess_audio(self, input_audios, threshold=0.9):
|
355 |
+
assert len(input_audios.shape) == 2, input_audios.shape
|
356 |
+
norm_value = torch.ones_like(input_audios[:,0])
|
357 |
+
max_volume = input_audios.abs().max(dim=-1)[0]
|
358 |
+
norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
|
359 |
+
return input_audios/norm_value.unsqueeze(-1)
|
360 |
+
|
361 |
+
def extract_wav2vec_embeds(self, input_audios,output_len):
|
362 |
+
wav2vec_stride = 2
|
363 |
+
|
364 |
+
wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
|
365 |
+
# print(wav2vec_embeds)
|
366 |
+
# print("audio.shape:",input_audios.shape)
|
367 |
+
wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer]
|
368 |
+
# print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape)
|
369 |
+
wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
|
370 |
+
return wav2vec_embeds_last
|
371 |
+
|
372 |
+
def extract_mert_embeds(self, input_audios):
|
373 |
+
prompt_stride = 3
|
374 |
+
inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
|
375 |
+
input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
|
376 |
+
prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
|
377 |
+
mert_emb= prompt_embeds[-1]
|
378 |
+
mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1)
|
379 |
+
|
380 |
+
return mert_emb
|
381 |
+
|
382 |
+
def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer):
|
383 |
+
self.bestrq.eval()
|
384 |
+
# print("audio shape:",input_audio_0.shape)
|
385 |
+
input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
|
386 |
+
# print("input_wav_mean.shape:",input_wav_mean.shape)
|
387 |
+
# input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device)
|
388 |
+
input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
|
389 |
+
layer_results = input_wav_mean['layer_results']
|
390 |
+
# print("layer_results.shape:",layer_results[layer].shape)
|
391 |
+
bestrq_emb = layer_results[layer]
|
392 |
+
bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
|
393 |
+
#[b,t,1024] t=t/960
|
394 |
+
#35.84s->batch,896,1024
|
395 |
+
return bestrq_emb
|
396 |
+
|
397 |
+
|
398 |
+
def extract_spk_embeds(self, input_audios):
|
399 |
+
spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
|
400 |
+
spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
|
401 |
+
return spk_embeds
|
402 |
+
|
403 |
+
def extract_lyric_feats(self, lyric):
|
404 |
+
with torch.no_grad():
|
405 |
+
try:
|
406 |
+
text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
|
407 |
+
except:
|
408 |
+
text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
|
409 |
+
text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
|
410 |
+
text_mask = text_mask.to(self.device)
|
411 |
+
text_encoder_hidden_states, text_mask, text_prompt_embeds = \
|
412 |
+
pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
|
413 |
+
text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
|
414 |
+
return text_encoder_hidden_states, text_mask
|
415 |
+
|
416 |
+
def extract_energy_bar(self, input_audios):
|
417 |
+
if(input_audios.shape[-1] % self.num_samples_perseg > 0):
|
418 |
+
energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
|
419 |
+
else:
|
420 |
+
energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
|
421 |
+
energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
|
422 |
+
energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
|
423 |
+
energy_embedding = self.energy_embedding(energy_bar)
|
424 |
+
energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
|
425 |
+
return energy_embedding
|
426 |
+
|
427 |
+
def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \
|
428 |
+
additional_feats = ['spk', 'lyric'], \
|
429 |
+
train_rvq=True, train_ssl=False,layer=5):
|
430 |
+
if not hasattr(self,"device"):
|
431 |
+
self.device = input_audios.device
|
432 |
+
if not hasattr(self,"dtype"):
|
433 |
+
self.dtype = input_audios.dtype
|
434 |
+
device = self.device
|
435 |
+
input_audio_0 = input_audios[:,0,:]
|
436 |
+
input_audio_1 = input_audios[:,1,:]
|
437 |
+
input_audio_0 = self.preprocess_audio(input_audio_0)
|
438 |
+
input_audio_1 = self.preprocess_audio(input_audio_1)
|
439 |
+
input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0
|
440 |
+
# energy_embedding = self.extract_energy_bar(input_audios)
|
441 |
+
# print("energy_embedding.shape:",energy_embedding.shape)
|
442 |
+
# with autocast(enabled=False):
|
443 |
+
if(train_ssl):
|
444 |
+
self.wav2vec.train()
|
445 |
+
wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
|
446 |
+
self.clap_embd_extractor.train()
|
447 |
+
prompt_embeds = self.extract_mert_embeds(input_audios)
|
448 |
+
if('spk' in additional_feats):
|
449 |
+
self.xvecmodel.train()
|
450 |
+
spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
|
451 |
+
else:
|
452 |
+
with torch.no_grad():
|
453 |
+
with autocast(enabled=False):
|
454 |
+
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
455 |
+
# mert_emb = self.extract_mert_embeds(input_audios_mert)
|
456 |
+
|
457 |
+
wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2])
|
458 |
+
|
459 |
+
bestrq_emb = bestrq_emb.detach()
|
460 |
+
if('lyric' in additional_feats):
|
461 |
+
text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
|
462 |
+
else:
|
463 |
+
text_encoder_hidden_states, text_mask = None, None
|
464 |
+
|
465 |
+
# prompt_embeds_13 = torch.cat([mert_emb_13, energy_embedding], 1)
|
466 |
+
# print("prompt_embes.shape:",prompt_embeds.shape)
|
467 |
+
#prompt_embes.shape: torch.Size([3, 1088, 896])
|
468 |
+
# print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
|
469 |
+
#wav2vec_embeds.shape:torch.Size([3, 1024, 896])
|
470 |
+
if(train_rvq):
|
471 |
+
quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
472 |
+
else:
|
473 |
+
bestrq_emb = bestrq_emb.float()
|
474 |
+
self.rvq_bestrq_emb.eval()
|
475 |
+
# with autocast(enabled=False):
|
476 |
+
quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
477 |
+
commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
|
478 |
+
codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
|
479 |
+
quantized_bestrq_emb = quantized_bestrq_emb.detach()
|
480 |
+
|
481 |
+
commitment_loss = commitment_loss_bestrq_emb
|
482 |
+
codebook_loss = codebook_loss_bestrq_emb
|
483 |
+
|
484 |
+
|
485 |
+
alpha=1
|
486 |
+
quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
|
487 |
+
|
488 |
+
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
489 |
+
# print("latent_masks.shape:",latent_masks.shape)
|
490 |
+
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
491 |
+
|
492 |
+
|
493 |
+
|
494 |
+
scenario = np.random.choice(['start_seg', 'other_seg'])
|
495 |
+
if(scenario == 'other_seg'):
|
496 |
+
for binx in range(input_audios.shape[0]):
|
497 |
+
# latent_masks[binx,0:64] = 1
|
498 |
+
latent_masks[binx,0:random.randint(64,128)] = 1
|
499 |
+
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
500 |
+
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
501 |
+
# print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape)
|
502 |
+
# print("latent_masks.shape:",latent_masks.shape)
|
503 |
+
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
504 |
+
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
505 |
+
|
506 |
+
|
507 |
+
|
508 |
+
|
509 |
+
if self.uncondition:
|
510 |
+
mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
|
511 |
+
if len(mask_indices) > 0:
|
512 |
+
quantized_bestrq_emb[mask_indices] = 0
|
513 |
+
# print("latents.shape:",latents.shape)
|
514 |
+
latents = latents.permute(0,2,1).contiguous()
|
515 |
+
latents = self.normfeat.project_sample(latents)
|
516 |
+
latents = latents.permute(0,2,1).contiguous()
|
517 |
+
incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
518 |
+
attention_mask=(latent_masks > 0.5)
|
519 |
+
B, L = attention_mask.size()
|
520 |
+
attention_mask = attention_mask.view(B, 1, L)
|
521 |
+
attention_mask = attention_mask * attention_mask.transpose(-1, -2)
|
522 |
+
attention_mask = attention_mask.unsqueeze(1)
|
523 |
+
# print("incontext_latents.shape:",incontext_latents.shape)
|
524 |
+
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
525 |
+
latent_mask_input = self.mask_emb(latent_masks)
|
526 |
+
#64+48+64+1024
|
527 |
+
loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
|
528 |
+
return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
|
529 |
+
|
530 |
+
def init_device_dtype(self, device, dtype):
|
531 |
+
self.device = device
|
532 |
+
self.dtype = dtype
|
533 |
+
|
534 |
+
@torch.no_grad()
|
535 |
+
def fetch_codes(self, input_audios, additional_feats,layer):
|
536 |
+
input_audio_0 = input_audios[[0],:]
|
537 |
+
input_audio_1 = input_audios[[1],:]
|
538 |
+
input_audio_0 = self.preprocess_audio(input_audio_0)
|
539 |
+
input_audio_1 = self.preprocess_audio(input_audio_1)
|
540 |
+
|
541 |
+
self.bestrq.eval()
|
542 |
+
|
543 |
+
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
544 |
+
# bestrq_middle = bestrq_middle.detach()
|
545 |
+
# bestrq_last = bestrq_last.detach()
|
546 |
+
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
547 |
+
bestrq_emb = bestrq_emb.detach()
|
548 |
+
|
549 |
+
# self.rvq_bestrq_middle.eval()
|
550 |
+
# quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
|
551 |
+
# self.rvq_bestrq_last.eval()
|
552 |
+
# quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
|
553 |
+
|
554 |
+
self.rvq_bestrq_emb.eval()
|
555 |
+
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
556 |
+
|
557 |
+
|
558 |
+
if('spk' in additional_feats):
|
559 |
+
self.xvecmodel.eval()
|
560 |
+
spk_embeds = self.extract_spk_embeds(input_audios)
|
561 |
+
else:
|
562 |
+
spk_embeds = None
|
563 |
+
|
564 |
+
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
565 |
+
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
566 |
+
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
567 |
+
return [codes_bestrq_emb], [bestrq_emb], spk_embeds
|
568 |
+
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
569 |
+
|
570 |
+
|
571 |
+
@torch.no_grad()
|
572 |
+
def fetch_codes_batch(self, input_audios, additional_feats,layer):
|
573 |
+
input_audio_0 = input_audios[:,0,:]
|
574 |
+
input_audio_1 = input_audios[:,1,:]
|
575 |
+
input_audio_0 = self.preprocess_audio(input_audio_0)
|
576 |
+
input_audio_1 = self.preprocess_audio(input_audio_1)
|
577 |
+
|
578 |
+
self.bestrq.eval()
|
579 |
+
|
580 |
+
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
581 |
+
# bestrq_middle = bestrq_middle.detach()
|
582 |
+
# bestrq_last = bestrq_last.detach()
|
583 |
+
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
584 |
+
bestrq_emb = bestrq_emb.detach()
|
585 |
+
|
586 |
+
# self.rvq_bestrq_middle.eval()
|
587 |
+
# quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
|
588 |
+
# self.rvq_bestrq_last.eval()
|
589 |
+
# quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
|
590 |
+
|
591 |
+
self.rvq_bestrq_emb.eval()
|
592 |
+
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
593 |
+
|
594 |
+
|
595 |
+
if('spk' in additional_feats):
|
596 |
+
self.xvecmodel.eval()
|
597 |
+
spk_embeds = self.extract_spk_embeds(input_audios)
|
598 |
+
else:
|
599 |
+
spk_embeds = None
|
600 |
+
|
601 |
+
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
602 |
+
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
603 |
+
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
604 |
+
return [codes_bestrq_emb], [bestrq_emb], spk_embeds
|
605 |
+
|
606 |
+
@torch.no_grad()
|
607 |
+
def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127,
|
608 |
+
guidance_scale=2, num_steps=20,
|
609 |
+
disable_progress=True, scenario='start_seg'):
|
610 |
+
classifier_free_guidance = guidance_scale > 1.0
|
611 |
+
device = self.device
|
612 |
+
dtype = self.dtype
|
613 |
+
# codes_bestrq_middle, codes_bestrq_last = codes
|
614 |
+
codes_bestrq_emb = codes[0]
|
615 |
+
|
616 |
+
|
617 |
+
batch_size = codes_bestrq_emb.shape[0]
|
618 |
+
|
619 |
+
|
620 |
+
quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
|
621 |
+
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
622 |
+
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
623 |
+
print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
624 |
+
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
625 |
+
|
626 |
+
|
627 |
+
|
628 |
+
|
629 |
+
if('spk' in additional_feats):
|
630 |
+
spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
|
631 |
+
|
632 |
+
num_frames = quantized_bestrq_emb.shape[1]
|
633 |
+
|
634 |
+
num_channels_latents = self.num_channels
|
635 |
+
shape = (batch_size, num_frames, 64)
|
636 |
+
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
637 |
+
|
638 |
+
|
639 |
+
|
640 |
+
latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
|
641 |
+
latent_masks[:,0:latent_length] = 2
|
642 |
+
if(scenario=='other_seg'):
|
643 |
+
latent_masks[:,0:incontext_length] = 1
|
644 |
+
|
645 |
+
|
646 |
+
|
647 |
+
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
648 |
+
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
649 |
+
true_latents = true_latents.permute(0,2,1).contiguous()
|
650 |
+
true_latents = self.normfeat.project_sample(true_latents)
|
651 |
+
true_latents = true_latents.permute(0,2,1).contiguous()
|
652 |
+
incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
653 |
+
incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
|
654 |
+
|
655 |
+
|
656 |
+
attention_mask=(latent_masks > 0.5)
|
657 |
+
B, L = attention_mask.size()
|
658 |
+
attention_mask = attention_mask.view(B, 1, L)
|
659 |
+
attention_mask = attention_mask * attention_mask.transpose(-1, -2)
|
660 |
+
attention_mask = attention_mask.unsqueeze(1)
|
661 |
+
latent_mask_input = self.mask_emb(latent_masks)
|
662 |
+
|
663 |
+
if('spk' in additional_feats):
|
664 |
+
# additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
|
665 |
+
additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1)
|
666 |
+
else:
|
667 |
+
# additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
|
668 |
+
additional_model_input = torch.cat([quantized_bestrq_emb],1)
|
669 |
+
|
670 |
+
temperature = 1.0
|
671 |
+
t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
|
672 |
+
latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
|
673 |
+
|
674 |
+
latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
|
675 |
+
latents = latents.permute(0,2,1).contiguous()
|
676 |
+
latents = self.normfeat.return_sample(latents)
|
677 |
+
# latents = latents.permute(0,2,1).contiguous()
|
678 |
+
return latents
|
679 |
+
|
680 |
+
@torch.no_grad()
|
681 |
+
def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
|
682 |
+
disable_progress=True,layer=5,scenario='start_seg'):
|
683 |
+
codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
|
684 |
+
|
685 |
+
latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
|
686 |
+
guidance_scale=guidance_scale, num_steps=num_steps, \
|
687 |
+
disable_progress=disable_progress,scenario=scenario)
|
688 |
+
return latents
|
689 |
+
|
690 |
+
@torch.no_grad()
|
691 |
+
def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
|
692 |
+
disable_progress=True,layer=5,scenario='start_seg'):
|
693 |
+
codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
|
694 |
+
import time
|
695 |
+
start = time.time()
|
696 |
+
latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
|
697 |
+
guidance_scale=guidance_scale, num_steps=num_steps, \
|
698 |
+
disable_progress=disable_progress,scenario=scenario)
|
699 |
+
return latents,time.time()-start
|
700 |
+
|
701 |
+
def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
|
702 |
+
divisor = 4
|
703 |
+
shape = (batch_size, num_channels_latents, num_frames, 32)
|
704 |
+
if(num_frames%divisor>0):
|
705 |
+
num_frames = round(num_frames/float(divisor))*divisor
|
706 |
+
shape = (batch_size, num_channels_latents, num_frames, 32)
|
707 |
+
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
708 |
+
return latents
|
709 |
+
|
710 |
+
|
codeclm/tokenizer/Flow1dVAE/model_2rvq.py
CHANGED
@@ -1,774 +1,774 @@
|
|
1 |
-
import yaml
|
2 |
-
import random
|
3 |
-
import inspect
|
4 |
-
import numpy as np
|
5 |
-
from tqdm import tqdm
|
6 |
-
import typing as tp
|
7 |
-
from abc import ABC
|
8 |
-
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
import torch.nn.functional as F
|
12 |
-
import torchaudio
|
13 |
-
|
14 |
-
from einops import repeat
|
15 |
-
from tools.torch_tools import wav_to_fbank
|
16 |
-
|
17 |
-
import diffusers
|
18 |
-
from diffusers.utils.torch_utils import randn_tensor
|
19 |
-
from diffusers import DDPMScheduler
|
20 |
-
from models.transformer_2d_flow import Transformer2DModel
|
21 |
-
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel
|
22 |
-
# from tools.get_mulan import get_mulan
|
23 |
-
from third_party.wespeaker.extract_embd import XVECModel
|
24 |
-
# from libs.rvq2 import RVQEmbedding
|
25 |
-
from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize
|
26 |
-
|
27 |
-
from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
|
28 |
-
from models_gpt.models.gpt2_config import GPT2Config
|
29 |
-
|
30 |
-
from torch.cuda.amp import autocast
|
31 |
-
|
32 |
-
|
33 |
-
from our_MERT_BESTRQ.test import load_model
|
34 |
-
|
35 |
-
class HubertModelWithFinalProj(HubertModel):
|
36 |
-
def __init__(self, config):
|
37 |
-
super().__init__(config)
|
38 |
-
|
39 |
-
# The final projection layer is only used for backward compatibility.
|
40 |
-
# Following https://github.com/auspicious3000/contentvec/issues/6
|
41 |
-
# Remove this layer is necessary to achieve the desired outcome.
|
42 |
-
print("hidden_size:",config.hidden_size)
|
43 |
-
print("classifier_proj_size:",config.classifier_proj_size)
|
44 |
-
self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
45 |
-
|
46 |
-
|
47 |
-
class SampleProcessor(torch.nn.Module):
|
48 |
-
def project_sample(self, x: torch.Tensor):
|
49 |
-
"""Project the original sample to the 'space' where the diffusion will happen."""
|
50 |
-
"""Project back from diffusion space to the actual sample space."""
|
51 |
-
return z
|
52 |
-
|
53 |
-
class Feature1DProcessor(SampleProcessor):
|
54 |
-
def __init__(self, dim: int = 100, power_std = 1., \
|
55 |
-
num_samples: int = 100_000, cal_num_frames: int = 600):
|
56 |
-
super().__init__()
|
57 |
-
|
58 |
-
self.num_samples = num_samples
|
59 |
-
self.dim = dim
|
60 |
-
self.power_std = power_std
|
61 |
-
self.cal_num_frames = cal_num_frames
|
62 |
-
self.register_buffer('counts', torch.zeros(1))
|
63 |
-
self.register_buffer('sum_x', torch.zeros(dim))
|
64 |
-
self.register_buffer('sum_x2', torch.zeros(dim))
|
65 |
-
self.register_buffer('sum_target_x2', torch.zeros(dim))
|
66 |
-
self.counts: torch.Tensor
|
67 |
-
self.sum_x: torch.Tensor
|
68 |
-
self.sum_x2: torch.Tensor
|
69 |
-
|
70 |
-
@property
|
71 |
-
def mean(self):
|
72 |
-
mean = self.sum_x / self.counts
|
73 |
-
if(self.counts < 10):
|
74 |
-
mean = torch.zeros_like(mean)
|
75 |
-
return mean
|
76 |
-
|
77 |
-
@property
|
78 |
-
def std(self):
|
79 |
-
std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
|
80 |
-
if(self.counts < 10):
|
81 |
-
std = torch.ones_like(std)
|
82 |
-
return std
|
83 |
-
|
84 |
-
@property
|
85 |
-
def target_std(self):
|
86 |
-
return 1
|
87 |
-
|
88 |
-
def project_sample(self, x: torch.Tensor):
|
89 |
-
assert x.dim() == 3
|
90 |
-
if self.counts.item() < self.num_samples:
|
91 |
-
self.counts += len(x)
|
92 |
-
self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
|
93 |
-
self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
|
94 |
-
rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
|
95 |
-
x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
|
96 |
-
return x
|
97 |
-
|
98 |
-
def return_sample(self, x: torch.Tensor):
|
99 |
-
assert x.dim() == 3
|
100 |
-
rescale = (self.std / self.target_std) ** self.power_std
|
101 |
-
# print(rescale, self.mean)
|
102 |
-
x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
|
103 |
-
return x
|
104 |
-
|
105 |
-
def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
|
106 |
-
if(prior_text_encoder_hidden_states.shape[1]<len_size):
|
107 |
-
prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
|
108 |
-
torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
|
109 |
-
prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
|
110 |
-
dtype=prior_text_encoder_hidden_states.dtype)],1)
|
111 |
-
prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
|
112 |
-
else:
|
113 |
-
prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
|
114 |
-
prior_text_mask = prior_text_mask[:,0:len_size]
|
115 |
-
prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
|
116 |
-
return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
|
117 |
-
|
118 |
-
class BASECFM(torch.nn.Module, ABC):
|
119 |
-
def __init__(
|
120 |
-
self,
|
121 |
-
estimator,
|
122 |
-
mlp,
|
123 |
-
ssl_layer
|
124 |
-
):
|
125 |
-
super().__init__()
|
126 |
-
self.sigma_min = 1e-4
|
127 |
-
|
128 |
-
self.estimator = estimator
|
129 |
-
self.mlp = mlp
|
130 |
-
self.ssl_layer = ssl_layer
|
131 |
-
|
132 |
-
@torch.inference_mode()
|
133 |
-
def forward(self, mu, n_timesteps, temperature=1.0):
|
134 |
-
"""Forward diffusion
|
135 |
-
|
136 |
-
Args:
|
137 |
-
mu (torch.Tensor): output of encoder
|
138 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
139 |
-
n_timesteps (int): number of diffusion steps
|
140 |
-
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
141 |
-
|
142 |
-
Returns:
|
143 |
-
sample: generated mel-spectrogram
|
144 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
145 |
-
"""
|
146 |
-
z = torch.randn_like(mu) * temperature
|
147 |
-
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
148 |
-
return self.solve_euler(z, t_span=t_span)
|
149 |
-
|
150 |
-
def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
|
151 |
-
"""
|
152 |
-
Fixed euler solver for ODEs.
|
153 |
-
Args:
|
154 |
-
x (torch.Tensor): random noise
|
155 |
-
t_span (torch.Tensor): n_timesteps interpolated
|
156 |
-
shape: (n_timesteps + 1,)
|
157 |
-
mu (torch.Tensor): output of encoder
|
158 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
159 |
-
"""
|
160 |
-
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
161 |
-
noise = x.clone()
|
162 |
-
|
163 |
-
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
164 |
-
# Or in future might add like a return_all_steps flag
|
165 |
-
sol = []
|
166 |
-
|
167 |
-
for step in tqdm(range(1, len(t_span))):
|
168 |
-
# print("incontext_x.shape:",incontext_x.shape)
|
169 |
-
# print("noise.shape:",noise.shape)
|
170 |
-
# print("t.shape:",t.shape)
|
171 |
-
x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
|
172 |
-
if(guidance_scale > 1.0):
|
173 |
-
|
174 |
-
model_input = torch.cat([ \
|
175 |
-
torch.cat([latent_mask_input, latent_mask_input], 0), \
|
176 |
-
torch.cat([incontext_x, incontext_x], 0), \
|
177 |
-
torch.cat([torch.zeros_like(mu), mu], 0), \
|
178 |
-
torch.cat([x, x], 0), \
|
179 |
-
], 2)
|
180 |
-
timestep=t.unsqueeze(-1).repeat(2)
|
181 |
-
|
182 |
-
dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
|
183 |
-
dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0)
|
184 |
-
dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond)
|
185 |
-
else:
|
186 |
-
model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
|
187 |
-
timestep=t.unsqueeze(-1)
|
188 |
-
dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
|
189 |
-
|
190 |
-
dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
|
191 |
-
# print("dphi_dt.shape:",dphi_dt.shape)
|
192 |
-
# print("x.shape:",x.shape)
|
193 |
-
|
194 |
-
x = x + dt * dphi_dt
|
195 |
-
t = t + dt
|
196 |
-
sol.append(x)
|
197 |
-
if step < len(t_span) - 1:
|
198 |
-
dt = t_span[step + 1] - t
|
199 |
-
|
200 |
-
return sol[-1]
|
201 |
-
|
202 |
-
def projection_loss(self,hidden_proj, bestrq_emb):
|
203 |
-
bsz = hidden_proj.shape[0]
|
204 |
-
|
205 |
-
hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
|
206 |
-
bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
|
207 |
-
|
208 |
-
proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
|
209 |
-
proj_loss = 1+proj_loss.mean()
|
210 |
-
|
211 |
-
return proj_loss
|
212 |
-
|
213 |
-
def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
|
214 |
-
"""Computes diffusion loss
|
215 |
-
|
216 |
-
Args:
|
217 |
-
x1 (torch.Tensor): Target
|
218 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
219 |
-
mu (torch.Tensor): output of encoder
|
220 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
221 |
-
|
222 |
-
Returns:
|
223 |
-
loss: conditional flow matching loss
|
224 |
-
y: conditional flow
|
225 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
226 |
-
"""
|
227 |
-
b = mu[0].shape[0]
|
228 |
-
len_x = x1.shape[2]
|
229 |
-
# random timestep
|
230 |
-
if(validation_mode):
|
231 |
-
t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
|
232 |
-
else:
|
233 |
-
t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
|
234 |
-
# sample noise p(x_0)
|
235 |
-
z = torch.randn_like(x1)
|
236 |
-
|
237 |
-
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
238 |
-
u = x1 - (1 - self.sigma_min) * z
|
239 |
-
# print("y.shape:",y.shape)
|
240 |
-
#self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state
|
241 |
-
model_input = torch.cat([*mu,y], 2)
|
242 |
-
t=t.squeeze(-1).squeeze(-1)
|
243 |
-
# print("model_input.shape:",model_input.shape)
|
244 |
-
# print("attention_mask.shape:",attention_mask.shape)
|
245 |
-
out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
|
246 |
-
hidden_layer = out.hidden_states[self.ssl_layer]
|
247 |
-
hidden_proj = self.mlp(hidden_layer)
|
248 |
-
# print("hidden_proj.shape:",hidden_proj.shape)
|
249 |
-
# print("mert_emb.shape:",mert_emb.shape)
|
250 |
-
# exit()
|
251 |
-
|
252 |
-
|
253 |
-
out = out.last_hidden_state
|
254 |
-
|
255 |
-
out=out[:,:,-len_x:]
|
256 |
-
# out=self.proj_out(out)
|
257 |
-
|
258 |
-
weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
|
259 |
-
# print("out.shape",out.shape)
|
260 |
-
# print("u.shape",u.shape)
|
261 |
-
loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
|
262 |
-
# print("hidden_proj.shape:",hidden_proj.shape)
|
263 |
-
# print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
|
264 |
-
loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
|
265 |
-
loss = loss_re + loss_cos * 0.5
|
266 |
-
# print("loss_cos:",loss_cos,loss_cos.device)
|
267 |
-
print("loss:",loss,loss.device)
|
268 |
-
# exit()
|
269 |
-
return loss, loss_re, loss_cos
|
270 |
-
|
271 |
-
class PromptCondAudioDiffusion(nn.Module):
|
272 |
-
def __init__(
|
273 |
-
self,
|
274 |
-
num_channels,
|
275 |
-
unet_model_name=None,
|
276 |
-
unet_model_config_path=None,
|
277 |
-
snr_gamma=None,
|
278 |
-
hubert_layer=None,
|
279 |
-
ssl_layer=None,
|
280 |
-
uncondition=True,
|
281 |
-
out_paint=False,
|
282 |
-
):
|
283 |
-
super().__init__()
|
284 |
-
|
285 |
-
assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
|
286 |
-
|
287 |
-
self.unet_model_name = unet_model_name
|
288 |
-
self.unet_model_config_path = unet_model_config_path
|
289 |
-
self.snr_gamma = snr_gamma
|
290 |
-
self.uncondition = uncondition
|
291 |
-
self.num_channels = num_channels
|
292 |
-
self.hubert_layer = hubert_layer
|
293 |
-
self.ssl_layer = ssl_layer
|
294 |
-
|
295 |
-
# https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
|
296 |
-
self.normfeat = Feature1DProcessor(dim=64)
|
297 |
-
|
298 |
-
self.sample_rate = 48000
|
299 |
-
self.num_samples_perseg = self.sample_rate * 20 // 1000
|
300 |
-
self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
|
301 |
-
self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
|
302 |
-
# self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
|
303 |
-
# self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
|
304 |
-
self.bestrq = load_model(
|
305 |
-
model_dir='path/to/our-MERT/mert_fairseq',
|
306 |
-
checkpoint_dir='checkpoint-120000.pt',
|
307 |
-
)
|
308 |
-
self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
|
309 |
-
self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
|
310 |
-
for v in self.bestrq.parameters():v.requires_grad = False
|
311 |
-
self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 2, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
|
312 |
-
# for v in self.rvq_bestrq_emb.parameters():
|
313 |
-
# print(v)
|
314 |
-
freeze_parameters='quantizers.0'
|
315 |
-
for name, param in self.rvq_bestrq_emb.named_parameters():
|
316 |
-
if freeze_parameters in name:
|
317 |
-
param.requires_grad = False
|
318 |
-
print("Freezing RVQ parameters:", name)
|
319 |
-
self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
|
320 |
-
for v in self.hubert.parameters():v.requires_grad = False
|
321 |
-
self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
|
322 |
-
# self.xvecmodel = XVECModel()
|
323 |
-
config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
|
324 |
-
unet = GPT2Model(config)
|
325 |
-
mlp = nn.Sequential(
|
326 |
-
nn.Linear(1200, 1024),
|
327 |
-
nn.SiLU(),
|
328 |
-
nn.Linear(1024, 1024),
|
329 |
-
nn.SiLU(),
|
330 |
-
nn.Linear(1024, 768)
|
331 |
-
)
|
332 |
-
self.set_from = "random"
|
333 |
-
self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
|
334 |
-
self.mask_emb = torch.nn.Embedding(3, 48)
|
335 |
-
print("Transformer initialized from pretrain.")
|
336 |
-
torch.cuda.empty_cache()
|
337 |
-
# self.unet.set_attn_processor(AttnProcessor2_0())
|
338 |
-
# self.unet.set_use_memory_efficient_attention_xformers(True)
|
339 |
-
|
340 |
-
# self.start_embedding = nn.Parameter(torch.randn(1,1024))
|
341 |
-
# self.end_embedding = nn.Parameter(torch.randn(1,1024))
|
342 |
-
|
343 |
-
def compute_snr(self, timesteps):
|
344 |
-
"""
|
345 |
-
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
346 |
-
"""
|
347 |
-
alphas_cumprod = self.noise_scheduler.alphas_cumprod
|
348 |
-
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
349 |
-
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
350 |
-
|
351 |
-
# Expand the tensors.
|
352 |
-
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
353 |
-
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
354 |
-
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
355 |
-
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
356 |
-
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
357 |
-
|
358 |
-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
359 |
-
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
360 |
-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
361 |
-
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
362 |
-
|
363 |
-
# Compute SNR.
|
364 |
-
snr = (alpha / sigma) ** 2
|
365 |
-
return snr
|
366 |
-
|
367 |
-
def preprocess_audio(self, input_audios, threshold=0.9):
|
368 |
-
assert len(input_audios.shape) == 2, input_audios.shape
|
369 |
-
norm_value = torch.ones_like(input_audios[:,0])
|
370 |
-
max_volume = input_audios.abs().max(dim=-1)[0]
|
371 |
-
norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
|
372 |
-
return input_audios/norm_value.unsqueeze(-1)
|
373 |
-
|
374 |
-
def extract_wav2vec_embeds(self, input_audios,output_len):
|
375 |
-
wav2vec_stride = 2
|
376 |
-
|
377 |
-
wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
|
378 |
-
# print(wav2vec_embeds)
|
379 |
-
# print("audio.shape:",input_audios.shape)
|
380 |
-
wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer]
|
381 |
-
# print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape)
|
382 |
-
wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
|
383 |
-
return wav2vec_embeds_last
|
384 |
-
|
385 |
-
def extract_mert_embeds(self, input_audios):
|
386 |
-
prompt_stride = 3
|
387 |
-
inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
|
388 |
-
input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
|
389 |
-
prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
|
390 |
-
mert_emb= prompt_embeds[-1]
|
391 |
-
mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1)
|
392 |
-
|
393 |
-
return mert_emb
|
394 |
-
|
395 |
-
def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer):
|
396 |
-
self.bestrq.eval()
|
397 |
-
# print("audio shape:",input_audio_0.shape)
|
398 |
-
input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
|
399 |
-
# print("input_wav_mean.shape:",input_wav_mean.shape)
|
400 |
-
# input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device)
|
401 |
-
input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
|
402 |
-
layer_results = input_wav_mean['layer_results']
|
403 |
-
# print("layer_results.shape:",layer_results[layer].shape)
|
404 |
-
bestrq_emb = layer_results[layer]
|
405 |
-
bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
|
406 |
-
#[b,t,1024] t=t/960
|
407 |
-
#35.84s->batch,896,1024
|
408 |
-
return bestrq_emb
|
409 |
-
|
410 |
-
|
411 |
-
def extract_spk_embeds(self, input_audios):
|
412 |
-
spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
|
413 |
-
spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
|
414 |
-
return spk_embeds
|
415 |
-
|
416 |
-
def extract_lyric_feats(self, lyric):
|
417 |
-
with torch.no_grad():
|
418 |
-
try:
|
419 |
-
text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
|
420 |
-
except:
|
421 |
-
text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
|
422 |
-
text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
|
423 |
-
text_mask = text_mask.to(self.device)
|
424 |
-
text_encoder_hidden_states, text_mask, text_prompt_embeds = \
|
425 |
-
pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
|
426 |
-
text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
|
427 |
-
return text_encoder_hidden_states, text_mask
|
428 |
-
|
429 |
-
def extract_energy_bar(self, input_audios):
|
430 |
-
if(input_audios.shape[-1] % self.num_samples_perseg > 0):
|
431 |
-
energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
|
432 |
-
else:
|
433 |
-
energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
|
434 |
-
energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
|
435 |
-
energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
|
436 |
-
energy_embedding = self.energy_embedding(energy_bar)
|
437 |
-
energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
|
438 |
-
return energy_embedding
|
439 |
-
|
440 |
-
def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \
|
441 |
-
additional_feats = ['spk', 'lyric'], \
|
442 |
-
train_rvq=True, train_ssl=False,layer=5):
|
443 |
-
if not hasattr(self,"device"):
|
444 |
-
self.device = input_audios.device
|
445 |
-
if not hasattr(self,"dtype"):
|
446 |
-
self.dtype = input_audios.dtype
|
447 |
-
device = self.device
|
448 |
-
input_audio_0 = input_audios[:,0,:]
|
449 |
-
input_audio_1 = input_audios[:,1,:]
|
450 |
-
input_audio_0 = self.preprocess_audio(input_audio_0)
|
451 |
-
input_audio_1 = self.preprocess_audio(input_audio_1)
|
452 |
-
input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0
|
453 |
-
# energy_embedding = self.extract_energy_bar(input_audios)
|
454 |
-
# print("energy_embedding.shape:",energy_embedding.shape)
|
455 |
-
# with autocast(enabled=False):
|
456 |
-
if(train_ssl):
|
457 |
-
self.wav2vec.train()
|
458 |
-
wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
|
459 |
-
self.clap_embd_extractor.train()
|
460 |
-
prompt_embeds = self.extract_mert_embeds(input_audios)
|
461 |
-
if('spk' in additional_feats):
|
462 |
-
self.xvecmodel.train()
|
463 |
-
spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
|
464 |
-
else:
|
465 |
-
with torch.no_grad():
|
466 |
-
with autocast(enabled=False):
|
467 |
-
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
468 |
-
# mert_emb = self.extract_mert_embeds(input_audios_mert)
|
469 |
-
|
470 |
-
wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2])
|
471 |
-
|
472 |
-
bestrq_emb = bestrq_emb.detach()
|
473 |
-
if('lyric' in additional_feats):
|
474 |
-
text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
|
475 |
-
else:
|
476 |
-
text_encoder_hidden_states, text_mask = None, None
|
477 |
-
|
478 |
-
|
479 |
-
if(train_rvq):
|
480 |
-
random_num=random.random()
|
481 |
-
if(random_num<0.6):
|
482 |
-
rvq_layer = 1
|
483 |
-
elif(random_num<0.8):
|
484 |
-
rvq_layer = 2
|
485 |
-
else:
|
486 |
-
rvq_layer = 4
|
487 |
-
quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t
|
488 |
-
else:
|
489 |
-
bestrq_emb = bestrq_emb.float()
|
490 |
-
self.rvq_bestrq_emb.eval()
|
491 |
-
# with autocast(enabled=False):
|
492 |
-
quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
493 |
-
commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
|
494 |
-
codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
|
495 |
-
quantized_bestrq_emb = quantized_bestrq_emb.detach()
|
496 |
-
|
497 |
-
commitment_loss = commitment_loss_bestrq_emb
|
498 |
-
codebook_loss = codebook_loss_bestrq_emb
|
499 |
-
|
500 |
-
|
501 |
-
alpha=1
|
502 |
-
quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
|
503 |
-
|
504 |
-
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
505 |
-
# print("latent_masks.shape:",latent_masks.shape)
|
506 |
-
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
scenario = np.random.choice(['start_seg', 'other_seg'])
|
511 |
-
if(scenario == 'other_seg'):
|
512 |
-
for binx in range(input_audios.shape[0]):
|
513 |
-
# latent_masks[binx,0:64] = 1
|
514 |
-
latent_masks[binx,0:random.randint(64,128)] = 1
|
515 |
-
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
516 |
-
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
517 |
-
# print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape)
|
518 |
-
# print("latent_masks.shape:",latent_masks.shape)
|
519 |
-
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
520 |
-
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
if self.uncondition:
|
526 |
-
mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
|
527 |
-
if len(mask_indices) > 0:
|
528 |
-
quantized_bestrq_emb[mask_indices] = 0
|
529 |
-
# print("latents.shape:",latents.shape)
|
530 |
-
latents = latents.permute(0,2,1).contiguous()
|
531 |
-
latents = self.normfeat.project_sample(latents)
|
532 |
-
latents = latents.permute(0,2,1).contiguous()
|
533 |
-
incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
534 |
-
attention_mask=(latent_masks > 0.5)
|
535 |
-
B, L = attention_mask.size()
|
536 |
-
attention_mask = attention_mask.view(B, 1, L)
|
537 |
-
attention_mask = attention_mask * attention_mask.transpose(-1, -2)
|
538 |
-
attention_mask = attention_mask.unsqueeze(1)
|
539 |
-
# print("incontext_latents.shape:",incontext_latents.shape)
|
540 |
-
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
541 |
-
latent_mask_input = self.mask_emb(latent_masks)
|
542 |
-
#64+48+64+1024
|
543 |
-
loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
|
544 |
-
return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
|
545 |
-
|
546 |
-
def init_device_dtype(self, device, dtype):
|
547 |
-
self.device = device
|
548 |
-
self.dtype = dtype
|
549 |
-
|
550 |
-
@torch.no_grad()
|
551 |
-
def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1):
|
552 |
-
input_audio_0 = input_audios[[0],:]
|
553 |
-
input_audio_1 = input_audios[[1],:]
|
554 |
-
input_audio_0 = self.preprocess_audio(input_audio_0)
|
555 |
-
input_audio_1 = self.preprocess_audio(input_audio_1)
|
556 |
-
|
557 |
-
self.bestrq.eval()
|
558 |
-
|
559 |
-
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
560 |
-
# bestrq_middle = bestrq_middle.detach()
|
561 |
-
# bestrq_last = bestrq_last.detach()
|
562 |
-
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
563 |
-
bestrq_emb = bestrq_emb.detach()
|
564 |
-
|
565 |
-
# self.rvq_bestrq_middle.eval()
|
566 |
-
# quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
|
567 |
-
# self.rvq_bestrq_last.eval()
|
568 |
-
# quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
|
569 |
-
|
570 |
-
self.rvq_bestrq_emb.eval()
|
571 |
-
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
|
572 |
-
codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
|
573 |
-
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
574 |
-
# exit()
|
575 |
-
|
576 |
-
|
577 |
-
if('spk' in additional_feats):
|
578 |
-
self.xvecmodel.eval()
|
579 |
-
spk_embeds = self.extract_spk_embeds(input_audios)
|
580 |
-
else:
|
581 |
-
spk_embeds = None
|
582 |
-
|
583 |
-
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
584 |
-
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
585 |
-
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
586 |
-
return [codes_bestrq_emb], [bestrq_emb], spk_embeds
|
587 |
-
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
588 |
-
|
589 |
-
@torch.no_grad()
|
590 |
-
def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1):
|
591 |
-
input_audio_0 = input_audios[:,0,:]
|
592 |
-
input_audio_1 = input_audios[:,1,:]
|
593 |
-
input_audio_0 = self.preprocess_audio(input_audio_0)
|
594 |
-
input_audio_1 = self.preprocess_audio(input_audio_1)
|
595 |
-
|
596 |
-
self.bestrq.eval()
|
597 |
-
|
598 |
-
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
599 |
-
# bestrq_middle = bestrq_middle.detach()
|
600 |
-
# bestrq_last = bestrq_last.detach()
|
601 |
-
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
602 |
-
bestrq_emb = bestrq_emb.detach()
|
603 |
-
|
604 |
-
# self.rvq_bestrq_middle.eval()
|
605 |
-
# quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
|
606 |
-
# self.rvq_bestrq_last.eval()
|
607 |
-
# quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
|
608 |
-
|
609 |
-
self.rvq_bestrq_emb.eval()
|
610 |
-
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
|
611 |
-
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
612 |
-
codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
|
613 |
-
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
614 |
-
# exit()
|
615 |
-
|
616 |
-
|
617 |
-
if('spk' in additional_feats):
|
618 |
-
self.xvecmodel.eval()
|
619 |
-
spk_embeds = self.extract_spk_embeds(input_audios)
|
620 |
-
else:
|
621 |
-
spk_embeds = None
|
622 |
-
|
623 |
-
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
624 |
-
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
625 |
-
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
626 |
-
return [codes_bestrq_emb], [bestrq_emb], spk_embeds
|
627 |
-
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
628 |
-
|
629 |
-
@torch.no_grad()
|
630 |
-
def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250):
|
631 |
-
input_audio_0 = input_audios[:,0,:]
|
632 |
-
input_audio_1 = input_audios[:,1,:]
|
633 |
-
input_audio_0 = self.preprocess_audio(input_audio_0)
|
634 |
-
input_audio_1 = self.preprocess_audio(input_audio_1)
|
635 |
-
|
636 |
-
self.bestrq.eval()
|
637 |
-
|
638 |
-
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
639 |
-
# bestrq_middle = bestrq_middle.detach()
|
640 |
-
# bestrq_last = bestrq_last.detach()
|
641 |
-
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
642 |
-
bestrq_emb = bestrq_emb.detach()
|
643 |
-
|
644 |
-
# self.rvq_bestrq_middle.eval()
|
645 |
-
# quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
|
646 |
-
# self.rvq_bestrq_last.eval()
|
647 |
-
# quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
|
648 |
-
|
649 |
-
self.rvq_bestrq_emb.eval()
|
650 |
-
bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds)
|
651 |
-
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
|
652 |
-
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
653 |
-
codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
|
654 |
-
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
655 |
-
# exit()
|
656 |
-
|
657 |
-
|
658 |
-
if('spk' in additional_feats):
|
659 |
-
self.xvecmodel.eval()
|
660 |
-
spk_embeds = self.extract_spk_embeds(input_audios)
|
661 |
-
else:
|
662 |
-
spk_embeds = None
|
663 |
-
|
664 |
-
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
665 |
-
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
666 |
-
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
667 |
-
return [codes_bestrq_emb], [bestrq_emb], spk_embeds
|
668 |
-
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
669 |
-
|
670 |
-
@torch.no_grad()
|
671 |
-
def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127,
|
672 |
-
guidance_scale=2, num_steps=20,
|
673 |
-
disable_progress=True, scenario='start_seg'):
|
674 |
-
classifier_free_guidance = guidance_scale > 1.0
|
675 |
-
device = self.device
|
676 |
-
dtype = self.dtype
|
677 |
-
# codes_bestrq_middle, codes_bestrq_last = codes
|
678 |
-
codes_bestrq_emb = codes[0]
|
679 |
-
|
680 |
-
|
681 |
-
batch_size = codes_bestrq_emb.shape[0]
|
682 |
-
|
683 |
-
|
684 |
-
quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
|
685 |
-
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
686 |
-
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
687 |
-
print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
688 |
-
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
if('spk' in additional_feats):
|
694 |
-
spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
|
695 |
-
|
696 |
-
num_frames = quantized_bestrq_emb.shape[1]
|
697 |
-
|
698 |
-
num_channels_latents = self.num_channels
|
699 |
-
shape = (batch_size, num_frames, 64)
|
700 |
-
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
|
705 |
-
latent_masks[:,0:latent_length] = 2
|
706 |
-
if(scenario=='other_seg'):
|
707 |
-
latent_masks[:,0:incontext_length] = 1
|
708 |
-
|
709 |
-
|
710 |
-
|
711 |
-
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
712 |
-
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
713 |
-
true_latents = true_latents.permute(0,2,1).contiguous()
|
714 |
-
true_latents = self.normfeat.project_sample(true_latents)
|
715 |
-
true_latents = true_latents.permute(0,2,1).contiguous()
|
716 |
-
incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
717 |
-
incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
|
718 |
-
|
719 |
-
|
720 |
-
attention_mask=(latent_masks > 0.5)
|
721 |
-
B, L = attention_mask.size()
|
722 |
-
attention_mask = attention_mask.view(B, 1, L)
|
723 |
-
attention_mask = attention_mask * attention_mask.transpose(-1, -2)
|
724 |
-
attention_mask = attention_mask.unsqueeze(1)
|
725 |
-
latent_mask_input = self.mask_emb(latent_masks)
|
726 |
-
|
727 |
-
if('spk' in additional_feats):
|
728 |
-
# additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
|
729 |
-
additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1)
|
730 |
-
else:
|
731 |
-
# additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
|
732 |
-
additional_model_input = torch.cat([quantized_bestrq_emb],1)
|
733 |
-
|
734 |
-
temperature = 1.0
|
735 |
-
t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
|
736 |
-
latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
|
737 |
-
|
738 |
-
latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
|
739 |
-
latents = latents.permute(0,2,1).contiguous()
|
740 |
-
latents = self.normfeat.return_sample(latents)
|
741 |
-
# latents = latents.permute(0,2,1).contiguous()
|
742 |
-
return latents
|
743 |
-
|
744 |
-
@torch.no_grad()
|
745 |
-
def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
|
746 |
-
disable_progress=True,layer=5,scenario='start_seg',rvq_num=1):
|
747 |
-
codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num)
|
748 |
-
|
749 |
-
latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
|
750 |
-
guidance_scale=guidance_scale, num_steps=num_steps, \
|
751 |
-
disable_progress=disable_progress,scenario=scenario)
|
752 |
-
return latents
|
753 |
-
|
754 |
-
@torch.no_grad()
|
755 |
-
def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
|
756 |
-
disable_progress=True,layer=5,scenario='start_seg'):
|
757 |
-
codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
|
758 |
-
import time
|
759 |
-
start = time.time()
|
760 |
-
latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
|
761 |
-
guidance_scale=guidance_scale, num_steps=num_steps, \
|
762 |
-
disable_progress=disable_progress,scenario=scenario)
|
763 |
-
return latents,time.time()-start
|
764 |
-
|
765 |
-
def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
|
766 |
-
divisor = 4
|
767 |
-
shape = (batch_size, num_channels_latents, num_frames, 32)
|
768 |
-
if(num_frames%divisor>0):
|
769 |
-
num_frames = round(num_frames/float(divisor))*divisor
|
770 |
-
shape = (batch_size, num_channels_latents, num_frames, 32)
|
771 |
-
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
772 |
-
return latents
|
773 |
-
|
774 |
-
|
|
|
1 |
+
import yaml
|
2 |
+
import random
|
3 |
+
import inspect
|
4 |
+
import numpy as np
|
5 |
+
from tqdm import tqdm
|
6 |
+
import typing as tp
|
7 |
+
from abc import ABC
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import torchaudio
|
13 |
+
|
14 |
+
from einops import repeat
|
15 |
+
from tools.torch_tools import wav_to_fbank
|
16 |
+
|
17 |
+
import diffusers
|
18 |
+
from diffusers.utils.torch_utils import randn_tensor
|
19 |
+
from diffusers import DDPMScheduler
|
20 |
+
from models.transformer_2d_flow import Transformer2DModel
|
21 |
+
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel
|
22 |
+
# from tools.get_mulan import get_mulan
|
23 |
+
from third_party.wespeaker.extract_embd import XVECModel
|
24 |
+
# from libs.rvq2 import RVQEmbedding
|
25 |
+
from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize
|
26 |
+
|
27 |
+
from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
|
28 |
+
from models_gpt.models.gpt2_config import GPT2Config
|
29 |
+
|
30 |
+
from torch.cuda.amp import autocast
|
31 |
+
|
32 |
+
|
33 |
+
from our_MERT_BESTRQ.test import load_model
|
34 |
+
|
35 |
+
class HubertModelWithFinalProj(HubertModel):
|
36 |
+
def __init__(self, config):
|
37 |
+
super().__init__(config)
|
38 |
+
|
39 |
+
# The final projection layer is only used for backward compatibility.
|
40 |
+
# Following https://github.com/auspicious3000/contentvec/issues/6
|
41 |
+
# Remove this layer is necessary to achieve the desired outcome.
|
42 |
+
print("hidden_size:",config.hidden_size)
|
43 |
+
print("classifier_proj_size:",config.classifier_proj_size)
|
44 |
+
self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
45 |
+
|
46 |
+
|
47 |
+
class SampleProcessor(torch.nn.Module):
|
48 |
+
def project_sample(self, x: torch.Tensor):
|
49 |
+
"""Project the original sample to the 'space' where the diffusion will happen."""
|
50 |
+
"""Project back from diffusion space to the actual sample space."""
|
51 |
+
return z
|
52 |
+
|
53 |
+
class Feature1DProcessor(SampleProcessor):
|
54 |
+
def __init__(self, dim: int = 100, power_std = 1., \
|
55 |
+
num_samples: int = 100_000, cal_num_frames: int = 600):
|
56 |
+
super().__init__()
|
57 |
+
|
58 |
+
self.num_samples = num_samples
|
59 |
+
self.dim = dim
|
60 |
+
self.power_std = power_std
|
61 |
+
self.cal_num_frames = cal_num_frames
|
62 |
+
self.register_buffer('counts', torch.zeros(1))
|
63 |
+
self.register_buffer('sum_x', torch.zeros(dim))
|
64 |
+
self.register_buffer('sum_x2', torch.zeros(dim))
|
65 |
+
self.register_buffer('sum_target_x2', torch.zeros(dim))
|
66 |
+
self.counts: torch.Tensor
|
67 |
+
self.sum_x: torch.Tensor
|
68 |
+
self.sum_x2: torch.Tensor
|
69 |
+
|
70 |
+
@property
|
71 |
+
def mean(self):
|
72 |
+
mean = self.sum_x / self.counts
|
73 |
+
if(self.counts < 10):
|
74 |
+
mean = torch.zeros_like(mean)
|
75 |
+
return mean
|
76 |
+
|
77 |
+
@property
|
78 |
+
def std(self):
|
79 |
+
std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
|
80 |
+
if(self.counts < 10):
|
81 |
+
std = torch.ones_like(std)
|
82 |
+
return std
|
83 |
+
|
84 |
+
@property
|
85 |
+
def target_std(self):
|
86 |
+
return 1
|
87 |
+
|
88 |
+
def project_sample(self, x: torch.Tensor):
|
89 |
+
assert x.dim() == 3
|
90 |
+
if self.counts.item() < self.num_samples:
|
91 |
+
self.counts += len(x)
|
92 |
+
self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
|
93 |
+
self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
|
94 |
+
rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
|
95 |
+
x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
|
96 |
+
return x
|
97 |
+
|
98 |
+
def return_sample(self, x: torch.Tensor):
|
99 |
+
assert x.dim() == 3
|
100 |
+
rescale = (self.std / self.target_std) ** self.power_std
|
101 |
+
# print(rescale, self.mean)
|
102 |
+
x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
|
103 |
+
return x
|
104 |
+
|
105 |
+
def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
|
106 |
+
if(prior_text_encoder_hidden_states.shape[1]<len_size):
|
107 |
+
prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
|
108 |
+
torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
|
109 |
+
prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
|
110 |
+
dtype=prior_text_encoder_hidden_states.dtype)],1)
|
111 |
+
prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
|
112 |
+
else:
|
113 |
+
prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
|
114 |
+
prior_text_mask = prior_text_mask[:,0:len_size]
|
115 |
+
prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
|
116 |
+
return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
|
117 |
+
|
118 |
+
class BASECFM(torch.nn.Module, ABC):
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
estimator,
|
122 |
+
mlp,
|
123 |
+
ssl_layer
|
124 |
+
):
|
125 |
+
super().__init__()
|
126 |
+
self.sigma_min = 1e-4
|
127 |
+
|
128 |
+
self.estimator = estimator
|
129 |
+
self.mlp = mlp
|
130 |
+
self.ssl_layer = ssl_layer
|
131 |
+
|
132 |
+
@torch.inference_mode()
|
133 |
+
def forward(self, mu, n_timesteps, temperature=1.0):
|
134 |
+
"""Forward diffusion
|
135 |
+
|
136 |
+
Args:
|
137 |
+
mu (torch.Tensor): output of encoder
|
138 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
139 |
+
n_timesteps (int): number of diffusion steps
|
140 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
sample: generated mel-spectrogram
|
144 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
145 |
+
"""
|
146 |
+
z = torch.randn_like(mu) * temperature
|
147 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
148 |
+
return self.solve_euler(z, t_span=t_span)
|
149 |
+
|
150 |
+
def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
|
151 |
+
"""
|
152 |
+
Fixed euler solver for ODEs.
|
153 |
+
Args:
|
154 |
+
x (torch.Tensor): random noise
|
155 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
156 |
+
shape: (n_timesteps + 1,)
|
157 |
+
mu (torch.Tensor): output of encoder
|
158 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
159 |
+
"""
|
160 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
161 |
+
noise = x.clone()
|
162 |
+
|
163 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
164 |
+
# Or in future might add like a return_all_steps flag
|
165 |
+
sol = []
|
166 |
+
|
167 |
+
for step in tqdm(range(1, len(t_span))):
|
168 |
+
# print("incontext_x.shape:",incontext_x.shape)
|
169 |
+
# print("noise.shape:",noise.shape)
|
170 |
+
# print("t.shape:",t.shape)
|
171 |
+
x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
|
172 |
+
if(guidance_scale > 1.0):
|
173 |
+
|
174 |
+
model_input = torch.cat([ \
|
175 |
+
torch.cat([latent_mask_input, latent_mask_input], 0), \
|
176 |
+
torch.cat([incontext_x, incontext_x], 0), \
|
177 |
+
torch.cat([torch.zeros_like(mu), mu], 0), \
|
178 |
+
torch.cat([x, x], 0), \
|
179 |
+
], 2)
|
180 |
+
timestep=t.unsqueeze(-1).repeat(2)
|
181 |
+
|
182 |
+
dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
|
183 |
+
dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0)
|
184 |
+
dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond)
|
185 |
+
else:
|
186 |
+
model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
|
187 |
+
timestep=t.unsqueeze(-1)
|
188 |
+
dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
|
189 |
+
|
190 |
+
dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
|
191 |
+
# print("dphi_dt.shape:",dphi_dt.shape)
|
192 |
+
# print("x.shape:",x.shape)
|
193 |
+
|
194 |
+
x = x + dt * dphi_dt
|
195 |
+
t = t + dt
|
196 |
+
sol.append(x)
|
197 |
+
if step < len(t_span) - 1:
|
198 |
+
dt = t_span[step + 1] - t
|
199 |
+
|
200 |
+
return sol[-1]
|
201 |
+
|
202 |
+
def projection_loss(self,hidden_proj, bestrq_emb):
|
203 |
+
bsz = hidden_proj.shape[0]
|
204 |
+
|
205 |
+
hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
|
206 |
+
bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
|
207 |
+
|
208 |
+
proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
|
209 |
+
proj_loss = 1+proj_loss.mean()
|
210 |
+
|
211 |
+
return proj_loss
|
212 |
+
|
213 |
+
def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
|
214 |
+
"""Computes diffusion loss
|
215 |
+
|
216 |
+
Args:
|
217 |
+
x1 (torch.Tensor): Target
|
218 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
219 |
+
mu (torch.Tensor): output of encoder
|
220 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
loss: conditional flow matching loss
|
224 |
+
y: conditional flow
|
225 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
226 |
+
"""
|
227 |
+
b = mu[0].shape[0]
|
228 |
+
len_x = x1.shape[2]
|
229 |
+
# random timestep
|
230 |
+
if(validation_mode):
|
231 |
+
t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
|
232 |
+
else:
|
233 |
+
t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
|
234 |
+
# sample noise p(x_0)
|
235 |
+
z = torch.randn_like(x1)
|
236 |
+
|
237 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
238 |
+
u = x1 - (1 - self.sigma_min) * z
|
239 |
+
# print("y.shape:",y.shape)
|
240 |
+
#self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state
|
241 |
+
model_input = torch.cat([*mu,y], 2)
|
242 |
+
t=t.squeeze(-1).squeeze(-1)
|
243 |
+
# print("model_input.shape:",model_input.shape)
|
244 |
+
# print("attention_mask.shape:",attention_mask.shape)
|
245 |
+
out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
|
246 |
+
hidden_layer = out.hidden_states[self.ssl_layer]
|
247 |
+
hidden_proj = self.mlp(hidden_layer)
|
248 |
+
# print("hidden_proj.shape:",hidden_proj.shape)
|
249 |
+
# print("mert_emb.shape:",mert_emb.shape)
|
250 |
+
# exit()
|
251 |
+
|
252 |
+
|
253 |
+
out = out.last_hidden_state
|
254 |
+
|
255 |
+
out=out[:,:,-len_x:]
|
256 |
+
# out=self.proj_out(out)
|
257 |
+
|
258 |
+
weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
|
259 |
+
# print("out.shape",out.shape)
|
260 |
+
# print("u.shape",u.shape)
|
261 |
+
loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
|
262 |
+
# print("hidden_proj.shape:",hidden_proj.shape)
|
263 |
+
# print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
|
264 |
+
loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
|
265 |
+
loss = loss_re + loss_cos * 0.5
|
266 |
+
# print("loss_cos:",loss_cos,loss_cos.device)
|
267 |
+
print("loss:",loss,loss.device)
|
268 |
+
# exit()
|
269 |
+
return loss, loss_re, loss_cos
|
270 |
+
|
271 |
+
class PromptCondAudioDiffusion(nn.Module):
|
272 |
+
def __init__(
|
273 |
+
self,
|
274 |
+
num_channels,
|
275 |
+
unet_model_name=None,
|
276 |
+
unet_model_config_path=None,
|
277 |
+
snr_gamma=None,
|
278 |
+
hubert_layer=None,
|
279 |
+
ssl_layer=None,
|
280 |
+
uncondition=True,
|
281 |
+
out_paint=False,
|
282 |
+
):
|
283 |
+
super().__init__()
|
284 |
+
|
285 |
+
assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
|
286 |
+
|
287 |
+
self.unet_model_name = unet_model_name
|
288 |
+
self.unet_model_config_path = unet_model_config_path
|
289 |
+
self.snr_gamma = snr_gamma
|
290 |
+
self.uncondition = uncondition
|
291 |
+
self.num_channels = num_channels
|
292 |
+
self.hubert_layer = hubert_layer
|
293 |
+
self.ssl_layer = ssl_layer
|
294 |
+
|
295 |
+
# https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
|
296 |
+
self.normfeat = Feature1DProcessor(dim=64)
|
297 |
+
|
298 |
+
self.sample_rate = 48000
|
299 |
+
self.num_samples_perseg = self.sample_rate * 20 // 1000
|
300 |
+
self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
|
301 |
+
self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
|
302 |
+
# self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
|
303 |
+
# self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
|
304 |
+
self.bestrq = load_model(
|
305 |
+
model_dir='path/to/our-MERT/mert_fairseq',
|
306 |
+
checkpoint_dir='checkpoint-120000.pt',
|
307 |
+
)
|
308 |
+
self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
|
309 |
+
self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
|
310 |
+
for v in self.bestrq.parameters():v.requires_grad = False
|
311 |
+
self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 2, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
|
312 |
+
# for v in self.rvq_bestrq_emb.parameters():
|
313 |
+
# print(v)
|
314 |
+
freeze_parameters='quantizers.0'
|
315 |
+
for name, param in self.rvq_bestrq_emb.named_parameters():
|
316 |
+
if freeze_parameters in name:
|
317 |
+
param.requires_grad = False
|
318 |
+
print("Freezing RVQ parameters:", name)
|
319 |
+
self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
|
320 |
+
for v in self.hubert.parameters():v.requires_grad = False
|
321 |
+
self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
|
322 |
+
# self.xvecmodel = XVECModel()
|
323 |
+
config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
|
324 |
+
unet = GPT2Model(config)
|
325 |
+
mlp = nn.Sequential(
|
326 |
+
nn.Linear(1200, 1024),
|
327 |
+
nn.SiLU(),
|
328 |
+
nn.Linear(1024, 1024),
|
329 |
+
nn.SiLU(),
|
330 |
+
nn.Linear(1024, 768)
|
331 |
+
)
|
332 |
+
self.set_from = "random"
|
333 |
+
self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
|
334 |
+
self.mask_emb = torch.nn.Embedding(3, 48)
|
335 |
+
print("Transformer initialized from pretrain.")
|
336 |
+
torch.cuda.empty_cache()
|
337 |
+
# self.unet.set_attn_processor(AttnProcessor2_0())
|
338 |
+
# self.unet.set_use_memory_efficient_attention_xformers(True)
|
339 |
+
|
340 |
+
# self.start_embedding = nn.Parameter(torch.randn(1,1024))
|
341 |
+
# self.end_embedding = nn.Parameter(torch.randn(1,1024))
|
342 |
+
|
343 |
+
def compute_snr(self, timesteps):
|
344 |
+
"""
|
345 |
+
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
346 |
+
"""
|
347 |
+
alphas_cumprod = self.noise_scheduler.alphas_cumprod
|
348 |
+
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
349 |
+
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
350 |
+
|
351 |
+
# Expand the tensors.
|
352 |
+
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
353 |
+
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
354 |
+
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
355 |
+
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
356 |
+
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
357 |
+
|
358 |
+
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
359 |
+
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
360 |
+
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
361 |
+
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
362 |
+
|
363 |
+
# Compute SNR.
|
364 |
+
snr = (alpha / sigma) ** 2
|
365 |
+
return snr
|
366 |
+
|
367 |
+
def preprocess_audio(self, input_audios, threshold=0.9):
|
368 |
+
assert len(input_audios.shape) == 2, input_audios.shape
|
369 |
+
norm_value = torch.ones_like(input_audios[:,0])
|
370 |
+
max_volume = input_audios.abs().max(dim=-1)[0]
|
371 |
+
norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
|
372 |
+
return input_audios/norm_value.unsqueeze(-1)
|
373 |
+
|
374 |
+
def extract_wav2vec_embeds(self, input_audios,output_len):
|
375 |
+
wav2vec_stride = 2
|
376 |
+
|
377 |
+
wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
|
378 |
+
# print(wav2vec_embeds)
|
379 |
+
# print("audio.shape:",input_audios.shape)
|
380 |
+
wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer]
|
381 |
+
# print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape)
|
382 |
+
wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
|
383 |
+
return wav2vec_embeds_last
|
384 |
+
|
385 |
+
def extract_mert_embeds(self, input_audios):
|
386 |
+
prompt_stride = 3
|
387 |
+
inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
|
388 |
+
input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
|
389 |
+
prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
|
390 |
+
mert_emb= prompt_embeds[-1]
|
391 |
+
mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1)
|
392 |
+
|
393 |
+
return mert_emb
|
394 |
+
|
395 |
+
def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer):
|
396 |
+
self.bestrq.eval()
|
397 |
+
# print("audio shape:",input_audio_0.shape)
|
398 |
+
input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
|
399 |
+
# print("input_wav_mean.shape:",input_wav_mean.shape)
|
400 |
+
# input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device)
|
401 |
+
input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
|
402 |
+
layer_results = input_wav_mean['layer_results']
|
403 |
+
# print("layer_results.shape:",layer_results[layer].shape)
|
404 |
+
bestrq_emb = layer_results[layer]
|
405 |
+
bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
|
406 |
+
#[b,t,1024] t=t/960
|
407 |
+
#35.84s->batch,896,1024
|
408 |
+
return bestrq_emb
|
409 |
+
|
410 |
+
|
411 |
+
def extract_spk_embeds(self, input_audios):
|
412 |
+
spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
|
413 |
+
spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
|
414 |
+
return spk_embeds
|
415 |
+
|
416 |
+
def extract_lyric_feats(self, lyric):
|
417 |
+
with torch.no_grad():
|
418 |
+
try:
|
419 |
+
text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
|
420 |
+
except:
|
421 |
+
text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
|
422 |
+
text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
|
423 |
+
text_mask = text_mask.to(self.device)
|
424 |
+
text_encoder_hidden_states, text_mask, text_prompt_embeds = \
|
425 |
+
pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
|
426 |
+
text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
|
427 |
+
return text_encoder_hidden_states, text_mask
|
428 |
+
|
429 |
+
def extract_energy_bar(self, input_audios):
|
430 |
+
if(input_audios.shape[-1] % self.num_samples_perseg > 0):
|
431 |
+
energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
|
432 |
+
else:
|
433 |
+
energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
|
434 |
+
energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
|
435 |
+
energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
|
436 |
+
energy_embedding = self.energy_embedding(energy_bar)
|
437 |
+
energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
|
438 |
+
return energy_embedding
|
439 |
+
|
440 |
+
def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \
|
441 |
+
additional_feats = ['spk', 'lyric'], \
|
442 |
+
train_rvq=True, train_ssl=False,layer=5):
|
443 |
+
if not hasattr(self,"device"):
|
444 |
+
self.device = input_audios.device
|
445 |
+
if not hasattr(self,"dtype"):
|
446 |
+
self.dtype = input_audios.dtype
|
447 |
+
device = self.device
|
448 |
+
input_audio_0 = input_audios[:,0,:]
|
449 |
+
input_audio_1 = input_audios[:,1,:]
|
450 |
+
input_audio_0 = self.preprocess_audio(input_audio_0)
|
451 |
+
input_audio_1 = self.preprocess_audio(input_audio_1)
|
452 |
+
input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0
|
453 |
+
# energy_embedding = self.extract_energy_bar(input_audios)
|
454 |
+
# print("energy_embedding.shape:",energy_embedding.shape)
|
455 |
+
# with autocast(enabled=False):
|
456 |
+
if(train_ssl):
|
457 |
+
self.wav2vec.train()
|
458 |
+
wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
|
459 |
+
self.clap_embd_extractor.train()
|
460 |
+
prompt_embeds = self.extract_mert_embeds(input_audios)
|
461 |
+
if('spk' in additional_feats):
|
462 |
+
self.xvecmodel.train()
|
463 |
+
spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
|
464 |
+
else:
|
465 |
+
with torch.no_grad():
|
466 |
+
with autocast(enabled=False):
|
467 |
+
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
468 |
+
# mert_emb = self.extract_mert_embeds(input_audios_mert)
|
469 |
+
|
470 |
+
wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2])
|
471 |
+
|
472 |
+
bestrq_emb = bestrq_emb.detach()
|
473 |
+
if('lyric' in additional_feats):
|
474 |
+
text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
|
475 |
+
else:
|
476 |
+
text_encoder_hidden_states, text_mask = None, None
|
477 |
+
|
478 |
+
|
479 |
+
if(train_rvq):
|
480 |
+
random_num=random.random()
|
481 |
+
if(random_num<0.6):
|
482 |
+
rvq_layer = 1
|
483 |
+
elif(random_num<0.8):
|
484 |
+
rvq_layer = 2
|
485 |
+
else:
|
486 |
+
rvq_layer = 4
|
487 |
+
quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t
|
488 |
+
else:
|
489 |
+
bestrq_emb = bestrq_emb.float()
|
490 |
+
self.rvq_bestrq_emb.eval()
|
491 |
+
# with autocast(enabled=False):
|
492 |
+
quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
493 |
+
commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
|
494 |
+
codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
|
495 |
+
quantized_bestrq_emb = quantized_bestrq_emb.detach()
|
496 |
+
|
497 |
+
commitment_loss = commitment_loss_bestrq_emb
|
498 |
+
codebook_loss = codebook_loss_bestrq_emb
|
499 |
+
|
500 |
+
|
501 |
+
alpha=1
|
502 |
+
quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
|
503 |
+
|
504 |
+
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
505 |
+
# print("latent_masks.shape:",latent_masks.shape)
|
506 |
+
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
507 |
+
|
508 |
+
|
509 |
+
|
510 |
+
scenario = np.random.choice(['start_seg', 'other_seg'])
|
511 |
+
if(scenario == 'other_seg'):
|
512 |
+
for binx in range(input_audios.shape[0]):
|
513 |
+
# latent_masks[binx,0:64] = 1
|
514 |
+
latent_masks[binx,0:random.randint(64,128)] = 1
|
515 |
+
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
516 |
+
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
517 |
+
# print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape)
|
518 |
+
# print("latent_masks.shape:",latent_masks.shape)
|
519 |
+
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
520 |
+
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
521 |
+
|
522 |
+
|
523 |
+
|
524 |
+
|
525 |
+
if self.uncondition:
|
526 |
+
mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
|
527 |
+
if len(mask_indices) > 0:
|
528 |
+
quantized_bestrq_emb[mask_indices] = 0
|
529 |
+
# print("latents.shape:",latents.shape)
|
530 |
+
latents = latents.permute(0,2,1).contiguous()
|
531 |
+
latents = self.normfeat.project_sample(latents)
|
532 |
+
latents = latents.permute(0,2,1).contiguous()
|
533 |
+
incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
534 |
+
attention_mask=(latent_masks > 0.5)
|
535 |
+
B, L = attention_mask.size()
|
536 |
+
attention_mask = attention_mask.view(B, 1, L)
|
537 |
+
attention_mask = attention_mask * attention_mask.transpose(-1, -2)
|
538 |
+
attention_mask = attention_mask.unsqueeze(1)
|
539 |
+
# print("incontext_latents.shape:",incontext_latents.shape)
|
540 |
+
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
541 |
+
latent_mask_input = self.mask_emb(latent_masks)
|
542 |
+
#64+48+64+1024
|
543 |
+
loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
|
544 |
+
return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
|
545 |
+
|
546 |
+
def init_device_dtype(self, device, dtype):
|
547 |
+
self.device = device
|
548 |
+
self.dtype = dtype
|
549 |
+
|
550 |
+
@torch.no_grad()
|
551 |
+
def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1):
|
552 |
+
input_audio_0 = input_audios[[0],:]
|
553 |
+
input_audio_1 = input_audios[[1],:]
|
554 |
+
input_audio_0 = self.preprocess_audio(input_audio_0)
|
555 |
+
input_audio_1 = self.preprocess_audio(input_audio_1)
|
556 |
+
|
557 |
+
self.bestrq.eval()
|
558 |
+
|
559 |
+
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
560 |
+
# bestrq_middle = bestrq_middle.detach()
|
561 |
+
# bestrq_last = bestrq_last.detach()
|
562 |
+
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
563 |
+
bestrq_emb = bestrq_emb.detach()
|
564 |
+
|
565 |
+
# self.rvq_bestrq_middle.eval()
|
566 |
+
# quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
|
567 |
+
# self.rvq_bestrq_last.eval()
|
568 |
+
# quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
|
569 |
+
|
570 |
+
self.rvq_bestrq_emb.eval()
|
571 |
+
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
|
572 |
+
codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
|
573 |
+
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
574 |
+
# exit()
|
575 |
+
|
576 |
+
|
577 |
+
if('spk' in additional_feats):
|
578 |
+
self.xvecmodel.eval()
|
579 |
+
spk_embeds = self.extract_spk_embeds(input_audios)
|
580 |
+
else:
|
581 |
+
spk_embeds = None
|
582 |
+
|
583 |
+
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
584 |
+
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
585 |
+
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
586 |
+
return [codes_bestrq_emb], [bestrq_emb], spk_embeds
|
587 |
+
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
588 |
+
|
589 |
+
@torch.no_grad()
|
590 |
+
def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1):
|
591 |
+
input_audio_0 = input_audios[:,0,:]
|
592 |
+
input_audio_1 = input_audios[:,1,:]
|
593 |
+
input_audio_0 = self.preprocess_audio(input_audio_0)
|
594 |
+
input_audio_1 = self.preprocess_audio(input_audio_1)
|
595 |
+
|
596 |
+
self.bestrq.eval()
|
597 |
+
|
598 |
+
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
599 |
+
# bestrq_middle = bestrq_middle.detach()
|
600 |
+
# bestrq_last = bestrq_last.detach()
|
601 |
+
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
602 |
+
bestrq_emb = bestrq_emb.detach()
|
603 |
+
|
604 |
+
# self.rvq_bestrq_middle.eval()
|
605 |
+
# quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
|
606 |
+
# self.rvq_bestrq_last.eval()
|
607 |
+
# quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
|
608 |
+
|
609 |
+
self.rvq_bestrq_emb.eval()
|
610 |
+
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
|
611 |
+
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
612 |
+
codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
|
613 |
+
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
614 |
+
# exit()
|
615 |
+
|
616 |
+
|
617 |
+
if('spk' in additional_feats):
|
618 |
+
self.xvecmodel.eval()
|
619 |
+
spk_embeds = self.extract_spk_embeds(input_audios)
|
620 |
+
else:
|
621 |
+
spk_embeds = None
|
622 |
+
|
623 |
+
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
624 |
+
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
625 |
+
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
626 |
+
return [codes_bestrq_emb], [bestrq_emb], spk_embeds
|
627 |
+
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
628 |
+
|
629 |
+
@torch.no_grad()
|
630 |
+
def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250):
|
631 |
+
input_audio_0 = input_audios[:,0,:]
|
632 |
+
input_audio_1 = input_audios[:,1,:]
|
633 |
+
input_audio_0 = self.preprocess_audio(input_audio_0)
|
634 |
+
input_audio_1 = self.preprocess_audio(input_audio_1)
|
635 |
+
|
636 |
+
self.bestrq.eval()
|
637 |
+
|
638 |
+
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
639 |
+
# bestrq_middle = bestrq_middle.detach()
|
640 |
+
# bestrq_last = bestrq_last.detach()
|
641 |
+
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
642 |
+
bestrq_emb = bestrq_emb.detach()
|
643 |
+
|
644 |
+
# self.rvq_bestrq_middle.eval()
|
645 |
+
# quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
|
646 |
+
# self.rvq_bestrq_last.eval()
|
647 |
+
# quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
|
648 |
+
|
649 |
+
self.rvq_bestrq_emb.eval()
|
650 |
+
bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds)
|
651 |
+
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
|
652 |
+
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
653 |
+
codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
|
654 |
+
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
655 |
+
# exit()
|
656 |
+
|
657 |
+
|
658 |
+
if('spk' in additional_feats):
|
659 |
+
self.xvecmodel.eval()
|
660 |
+
spk_embeds = self.extract_spk_embeds(input_audios)
|
661 |
+
else:
|
662 |
+
spk_embeds = None
|
663 |
+
|
664 |
+
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
665 |
+
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
666 |
+
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
667 |
+
return [codes_bestrq_emb], [bestrq_emb], spk_embeds
|
668 |
+
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
669 |
+
|
670 |
+
@torch.no_grad()
|
671 |
+
def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127,
|
672 |
+
guidance_scale=2, num_steps=20,
|
673 |
+
disable_progress=True, scenario='start_seg'):
|
674 |
+
classifier_free_guidance = guidance_scale > 1.0
|
675 |
+
device = self.device
|
676 |
+
dtype = self.dtype
|
677 |
+
# codes_bestrq_middle, codes_bestrq_last = codes
|
678 |
+
codes_bestrq_emb = codes[0]
|
679 |
+
|
680 |
+
|
681 |
+
batch_size = codes_bestrq_emb.shape[0]
|
682 |
+
|
683 |
+
|
684 |
+
quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
|
685 |
+
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
686 |
+
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
687 |
+
print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
688 |
+
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
689 |
+
|
690 |
+
|
691 |
+
|
692 |
+
|
693 |
+
if('spk' in additional_feats):
|
694 |
+
spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
|
695 |
+
|
696 |
+
num_frames = quantized_bestrq_emb.shape[1]
|
697 |
+
|
698 |
+
num_channels_latents = self.num_channels
|
699 |
+
shape = (batch_size, num_frames, 64)
|
700 |
+
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
701 |
+
|
702 |
+
|
703 |
+
|
704 |
+
latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
|
705 |
+
latent_masks[:,0:latent_length] = 2
|
706 |
+
if(scenario=='other_seg'):
|
707 |
+
latent_masks[:,0:incontext_length] = 1
|
708 |
+
|
709 |
+
|
710 |
+
|
711 |
+
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
712 |
+
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
713 |
+
true_latents = true_latents.permute(0,2,1).contiguous()
|
714 |
+
true_latents = self.normfeat.project_sample(true_latents)
|
715 |
+
true_latents = true_latents.permute(0,2,1).contiguous()
|
716 |
+
incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
717 |
+
incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
|
718 |
+
|
719 |
+
|
720 |
+
attention_mask=(latent_masks > 0.5)
|
721 |
+
B, L = attention_mask.size()
|
722 |
+
attention_mask = attention_mask.view(B, 1, L)
|
723 |
+
attention_mask = attention_mask * attention_mask.transpose(-1, -2)
|
724 |
+
attention_mask = attention_mask.unsqueeze(1)
|
725 |
+
latent_mask_input = self.mask_emb(latent_masks)
|
726 |
+
|
727 |
+
if('spk' in additional_feats):
|
728 |
+
# additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
|
729 |
+
additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1)
|
730 |
+
else:
|
731 |
+
# additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
|
732 |
+
additional_model_input = torch.cat([quantized_bestrq_emb],1)
|
733 |
+
|
734 |
+
temperature = 1.0
|
735 |
+
t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
|
736 |
+
latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
|
737 |
+
|
738 |
+
latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
|
739 |
+
latents = latents.permute(0,2,1).contiguous()
|
740 |
+
latents = self.normfeat.return_sample(latents)
|
741 |
+
# latents = latents.permute(0,2,1).contiguous()
|
742 |
+
return latents
|
743 |
+
|
744 |
+
@torch.no_grad()
|
745 |
+
def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
|
746 |
+
disable_progress=True,layer=5,scenario='start_seg',rvq_num=1):
|
747 |
+
codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num)
|
748 |
+
|
749 |
+
latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
|
750 |
+
guidance_scale=guidance_scale, num_steps=num_steps, \
|
751 |
+
disable_progress=disable_progress,scenario=scenario)
|
752 |
+
return latents
|
753 |
+
|
754 |
+
@torch.no_grad()
|
755 |
+
def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
|
756 |
+
disable_progress=True,layer=5,scenario='start_seg'):
|
757 |
+
codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
|
758 |
+
import time
|
759 |
+
start = time.time()
|
760 |
+
latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
|
761 |
+
guidance_scale=guidance_scale, num_steps=num_steps, \
|
762 |
+
disable_progress=disable_progress,scenario=scenario)
|
763 |
+
return latents,time.time()-start
|
764 |
+
|
765 |
+
def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
|
766 |
+
divisor = 4
|
767 |
+
shape = (batch_size, num_channels_latents, num_frames, 32)
|
768 |
+
if(num_frames%divisor>0):
|
769 |
+
num_frames = round(num_frames/float(divisor))*divisor
|
770 |
+
shape = (batch_size, num_channels_latents, num_frames, 32)
|
771 |
+
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
772 |
+
return latents
|
773 |
+
|
774 |
+
|
codeclm/tokenizer/Flow1dVAE/model_4rvq.py
CHANGED
@@ -1,774 +1,774 @@
|
|
1 |
-
import yaml
|
2 |
-
import random
|
3 |
-
import inspect
|
4 |
-
import numpy as np
|
5 |
-
from tqdm import tqdm
|
6 |
-
import typing as tp
|
7 |
-
from abc import ABC
|
8 |
-
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
import torch.nn.functional as F
|
12 |
-
import torchaudio
|
13 |
-
|
14 |
-
from einops import repeat
|
15 |
-
from tools.torch_tools import wav_to_fbank
|
16 |
-
|
17 |
-
import diffusers
|
18 |
-
from diffusers.utils.torch_utils import randn_tensor
|
19 |
-
from diffusers import DDPMScheduler
|
20 |
-
from models.transformer_2d_flow import Transformer2DModel
|
21 |
-
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel
|
22 |
-
# from tools.get_mulan import get_mulan
|
23 |
-
from third_party.wespeaker.extract_embd import XVECModel
|
24 |
-
# from libs.rvq2 import RVQEmbedding
|
25 |
-
from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize
|
26 |
-
|
27 |
-
from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
|
28 |
-
from models_gpt.models.gpt2_config import GPT2Config
|
29 |
-
|
30 |
-
from torch.cuda.amp import autocast
|
31 |
-
|
32 |
-
|
33 |
-
from our_MERT_BESTRQ.test import load_model
|
34 |
-
|
35 |
-
class HubertModelWithFinalProj(HubertModel):
|
36 |
-
def __init__(self, config):
|
37 |
-
super().__init__(config)
|
38 |
-
|
39 |
-
# The final projection layer is only used for backward compatibility.
|
40 |
-
# Following https://github.com/auspicious3000/contentvec/issues/6
|
41 |
-
# Remove this layer is necessary to achieve the desired outcome.
|
42 |
-
print("hidden_size:",config.hidden_size)
|
43 |
-
print("classifier_proj_size:",config.classifier_proj_size)
|
44 |
-
self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
45 |
-
|
46 |
-
|
47 |
-
class SampleProcessor(torch.nn.Module):
|
48 |
-
def project_sample(self, x: torch.Tensor):
|
49 |
-
"""Project the original sample to the 'space' where the diffusion will happen."""
|
50 |
-
"""Project back from diffusion space to the actual sample space."""
|
51 |
-
return z
|
52 |
-
|
53 |
-
class Feature1DProcessor(SampleProcessor):
|
54 |
-
def __init__(self, dim: int = 100, power_std = 1., \
|
55 |
-
num_samples: int = 100_000, cal_num_frames: int = 600):
|
56 |
-
super().__init__()
|
57 |
-
|
58 |
-
self.num_samples = num_samples
|
59 |
-
self.dim = dim
|
60 |
-
self.power_std = power_std
|
61 |
-
self.cal_num_frames = cal_num_frames
|
62 |
-
self.register_buffer('counts', torch.zeros(1))
|
63 |
-
self.register_buffer('sum_x', torch.zeros(dim))
|
64 |
-
self.register_buffer('sum_x2', torch.zeros(dim))
|
65 |
-
self.register_buffer('sum_target_x2', torch.zeros(dim))
|
66 |
-
self.counts: torch.Tensor
|
67 |
-
self.sum_x: torch.Tensor
|
68 |
-
self.sum_x2: torch.Tensor
|
69 |
-
|
70 |
-
@property
|
71 |
-
def mean(self):
|
72 |
-
mean = self.sum_x / self.counts
|
73 |
-
if(self.counts < 10):
|
74 |
-
mean = torch.zeros_like(mean)
|
75 |
-
return mean
|
76 |
-
|
77 |
-
@property
|
78 |
-
def std(self):
|
79 |
-
std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
|
80 |
-
if(self.counts < 10):
|
81 |
-
std = torch.ones_like(std)
|
82 |
-
return std
|
83 |
-
|
84 |
-
@property
|
85 |
-
def target_std(self):
|
86 |
-
return 1
|
87 |
-
|
88 |
-
def project_sample(self, x: torch.Tensor):
|
89 |
-
assert x.dim() == 3
|
90 |
-
if self.counts.item() < self.num_samples:
|
91 |
-
self.counts += len(x)
|
92 |
-
self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
|
93 |
-
self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
|
94 |
-
rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
|
95 |
-
x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
|
96 |
-
return x
|
97 |
-
|
98 |
-
def return_sample(self, x: torch.Tensor):
|
99 |
-
assert x.dim() == 3
|
100 |
-
rescale = (self.std / self.target_std) ** self.power_std
|
101 |
-
# print(rescale, self.mean)
|
102 |
-
x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
|
103 |
-
return x
|
104 |
-
|
105 |
-
def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
|
106 |
-
if(prior_text_encoder_hidden_states.shape[1]<len_size):
|
107 |
-
prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
|
108 |
-
torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
|
109 |
-
prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
|
110 |
-
dtype=prior_text_encoder_hidden_states.dtype)],1)
|
111 |
-
prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
|
112 |
-
else:
|
113 |
-
prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
|
114 |
-
prior_text_mask = prior_text_mask[:,0:len_size]
|
115 |
-
prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
|
116 |
-
return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
|
117 |
-
|
118 |
-
class BASECFM(torch.nn.Module, ABC):
|
119 |
-
def __init__(
|
120 |
-
self,
|
121 |
-
estimator,
|
122 |
-
mlp,
|
123 |
-
ssl_layer
|
124 |
-
):
|
125 |
-
super().__init__()
|
126 |
-
self.sigma_min = 1e-4
|
127 |
-
|
128 |
-
self.estimator = estimator
|
129 |
-
self.mlp = mlp
|
130 |
-
self.ssl_layer = ssl_layer
|
131 |
-
|
132 |
-
@torch.inference_mode()
|
133 |
-
def forward(self, mu, n_timesteps, temperature=1.0):
|
134 |
-
"""Forward diffusion
|
135 |
-
|
136 |
-
Args:
|
137 |
-
mu (torch.Tensor): output of encoder
|
138 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
139 |
-
n_timesteps (int): number of diffusion steps
|
140 |
-
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
141 |
-
|
142 |
-
Returns:
|
143 |
-
sample: generated mel-spectrogram
|
144 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
145 |
-
"""
|
146 |
-
z = torch.randn_like(mu) * temperature
|
147 |
-
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
148 |
-
return self.solve_euler(z, t_span=t_span)
|
149 |
-
|
150 |
-
def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
|
151 |
-
"""
|
152 |
-
Fixed euler solver for ODEs.
|
153 |
-
Args:
|
154 |
-
x (torch.Tensor): random noise
|
155 |
-
t_span (torch.Tensor): n_timesteps interpolated
|
156 |
-
shape: (n_timesteps + 1,)
|
157 |
-
mu (torch.Tensor): output of encoder
|
158 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
159 |
-
"""
|
160 |
-
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
161 |
-
noise = x.clone()
|
162 |
-
|
163 |
-
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
164 |
-
# Or in future might add like a return_all_steps flag
|
165 |
-
sol = []
|
166 |
-
|
167 |
-
for step in tqdm(range(1, len(t_span))):
|
168 |
-
print("incontext_x.shape:",incontext_x.shape)
|
169 |
-
print("noise.shape:",noise.shape)
|
170 |
-
print("t.shape:",t.shape)
|
171 |
-
x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
|
172 |
-
if(guidance_scale > 1.0):
|
173 |
-
|
174 |
-
model_input = torch.cat([ \
|
175 |
-
torch.cat([latent_mask_input, latent_mask_input], 0), \
|
176 |
-
torch.cat([incontext_x, incontext_x], 0), \
|
177 |
-
torch.cat([torch.zeros_like(mu), mu], 0), \
|
178 |
-
torch.cat([x, x], 0), \
|
179 |
-
], 2)
|
180 |
-
timestep=t.unsqueeze(-1).repeat(2)
|
181 |
-
|
182 |
-
dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
|
183 |
-
dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0)
|
184 |
-
dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond)
|
185 |
-
else:
|
186 |
-
model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
|
187 |
-
timestep=t.unsqueeze(-1)
|
188 |
-
dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
|
189 |
-
|
190 |
-
dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
|
191 |
-
print("dphi_dt.shape:",dphi_dt.shape)
|
192 |
-
print("x.shape:",x.shape)
|
193 |
-
|
194 |
-
x = x + dt * dphi_dt
|
195 |
-
t = t + dt
|
196 |
-
sol.append(x)
|
197 |
-
if step < len(t_span) - 1:
|
198 |
-
dt = t_span[step + 1] - t
|
199 |
-
|
200 |
-
return sol[-1]
|
201 |
-
|
202 |
-
def projection_loss(self,hidden_proj, bestrq_emb):
|
203 |
-
bsz = hidden_proj.shape[0]
|
204 |
-
|
205 |
-
hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
|
206 |
-
bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
|
207 |
-
|
208 |
-
proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
|
209 |
-
proj_loss = 1+proj_loss.mean()
|
210 |
-
|
211 |
-
return proj_loss
|
212 |
-
|
213 |
-
def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
|
214 |
-
"""Computes diffusion loss
|
215 |
-
|
216 |
-
Args:
|
217 |
-
x1 (torch.Tensor): Target
|
218 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
219 |
-
mu (torch.Tensor): output of encoder
|
220 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
221 |
-
|
222 |
-
Returns:
|
223 |
-
loss: conditional flow matching loss
|
224 |
-
y: conditional flow
|
225 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
226 |
-
"""
|
227 |
-
b = mu[0].shape[0]
|
228 |
-
len_x = x1.shape[2]
|
229 |
-
# random timestep
|
230 |
-
if(validation_mode):
|
231 |
-
t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
|
232 |
-
else:
|
233 |
-
t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
|
234 |
-
# sample noise p(x_0)
|
235 |
-
z = torch.randn_like(x1)
|
236 |
-
|
237 |
-
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
238 |
-
u = x1 - (1 - self.sigma_min) * z
|
239 |
-
# print("y.shape:",y.shape)
|
240 |
-
#self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state
|
241 |
-
model_input = torch.cat([*mu,y], 2)
|
242 |
-
t=t.squeeze(-1).squeeze(-1)
|
243 |
-
# print("model_input.shape:",model_input.shape)
|
244 |
-
# print("attention_mask.shape:",attention_mask.shape)
|
245 |
-
out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
|
246 |
-
hidden_layer = out.hidden_states[self.ssl_layer]
|
247 |
-
hidden_proj = self.mlp(hidden_layer)
|
248 |
-
# print("hidden_proj.shape:",hidden_proj.shape)
|
249 |
-
# print("mert_emb.shape:",mert_emb.shape)
|
250 |
-
# exit()
|
251 |
-
|
252 |
-
|
253 |
-
out = out.last_hidden_state
|
254 |
-
|
255 |
-
out=out[:,:,-len_x:]
|
256 |
-
# out=self.proj_out(out)
|
257 |
-
|
258 |
-
weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
|
259 |
-
# print("out.shape",out.shape)
|
260 |
-
# print("u.shape",u.shape)
|
261 |
-
loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
|
262 |
-
# print("hidden_proj.shape:",hidden_proj.shape)
|
263 |
-
# print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
|
264 |
-
loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
|
265 |
-
loss = loss_re + loss_cos * 0.5
|
266 |
-
# print("loss_cos:",loss_cos,loss_cos.device)
|
267 |
-
print("loss:",loss,loss.device)
|
268 |
-
# exit()
|
269 |
-
return loss, loss_re, loss_cos
|
270 |
-
|
271 |
-
class PromptCondAudioDiffusion(nn.Module):
|
272 |
-
def __init__(
|
273 |
-
self,
|
274 |
-
num_channels,
|
275 |
-
unet_model_name=None,
|
276 |
-
unet_model_config_path=None,
|
277 |
-
snr_gamma=None,
|
278 |
-
hubert_layer=None,
|
279 |
-
ssl_layer=None,
|
280 |
-
uncondition=True,
|
281 |
-
out_paint=False,
|
282 |
-
):
|
283 |
-
super().__init__()
|
284 |
-
|
285 |
-
assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
|
286 |
-
|
287 |
-
self.unet_model_name = unet_model_name
|
288 |
-
self.unet_model_config_path = unet_model_config_path
|
289 |
-
self.snr_gamma = snr_gamma
|
290 |
-
self.uncondition = uncondition
|
291 |
-
self.num_channels = num_channels
|
292 |
-
self.hubert_layer = hubert_layer
|
293 |
-
self.ssl_layer = ssl_layer
|
294 |
-
|
295 |
-
# https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
|
296 |
-
self.normfeat = Feature1DProcessor(dim=64)
|
297 |
-
|
298 |
-
self.sample_rate = 48000
|
299 |
-
self.num_samples_perseg = self.sample_rate * 20 // 1000
|
300 |
-
self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
|
301 |
-
self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
|
302 |
-
# self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
|
303 |
-
# self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
|
304 |
-
self.bestrq = load_model(
|
305 |
-
model_dir='path/to/our-MERT/mert_fairseq',
|
306 |
-
checkpoint_dir='checkpoint-120000.pt',
|
307 |
-
)
|
308 |
-
self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
|
309 |
-
self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
|
310 |
-
for v in self.bestrq.parameters():v.requires_grad = False
|
311 |
-
self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
|
312 |
-
# for v in self.rvq_bestrq_emb.parameters():
|
313 |
-
# print(v)
|
314 |
-
freeze_parameters='quantizers.0'
|
315 |
-
for name, param in self.rvq_bestrq_emb.named_parameters():
|
316 |
-
if freeze_parameters in name:
|
317 |
-
param.requires_grad = False
|
318 |
-
print("Freezing RVQ parameters:", name)
|
319 |
-
self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
|
320 |
-
for v in self.hubert.parameters():v.requires_grad = False
|
321 |
-
self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
|
322 |
-
# self.xvecmodel = XVECModel()
|
323 |
-
config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
|
324 |
-
unet = GPT2Model(config)
|
325 |
-
mlp = nn.Sequential(
|
326 |
-
nn.Linear(1200, 1024),
|
327 |
-
nn.SiLU(),
|
328 |
-
nn.Linear(1024, 1024),
|
329 |
-
nn.SiLU(),
|
330 |
-
nn.Linear(1024, 768)
|
331 |
-
)
|
332 |
-
self.set_from = "random"
|
333 |
-
self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
|
334 |
-
self.mask_emb = torch.nn.Embedding(3, 48)
|
335 |
-
print("Transformer initialized from pretrain.")
|
336 |
-
torch.cuda.empty_cache()
|
337 |
-
# self.unet.set_attn_processor(AttnProcessor2_0())
|
338 |
-
# self.unet.set_use_memory_efficient_attention_xformers(True)
|
339 |
-
|
340 |
-
# self.start_embedding = nn.Parameter(torch.randn(1,1024))
|
341 |
-
# self.end_embedding = nn.Parameter(torch.randn(1,1024))
|
342 |
-
|
343 |
-
def compute_snr(self, timesteps):
|
344 |
-
"""
|
345 |
-
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
346 |
-
"""
|
347 |
-
alphas_cumprod = self.noise_scheduler.alphas_cumprod
|
348 |
-
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
349 |
-
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
350 |
-
|
351 |
-
# Expand the tensors.
|
352 |
-
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
353 |
-
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
354 |
-
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
355 |
-
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
356 |
-
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
357 |
-
|
358 |
-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
359 |
-
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
360 |
-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
361 |
-
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
362 |
-
|
363 |
-
# Compute SNR.
|
364 |
-
snr = (alpha / sigma) ** 2
|
365 |
-
return snr
|
366 |
-
|
367 |
-
def preprocess_audio(self, input_audios, threshold=0.9):
|
368 |
-
assert len(input_audios.shape) == 2, input_audios.shape
|
369 |
-
norm_value = torch.ones_like(input_audios[:,0])
|
370 |
-
max_volume = input_audios.abs().max(dim=-1)[0]
|
371 |
-
norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
|
372 |
-
return input_audios/norm_value.unsqueeze(-1)
|
373 |
-
|
374 |
-
def extract_wav2vec_embeds(self, input_audios,output_len):
|
375 |
-
wav2vec_stride = 2
|
376 |
-
|
377 |
-
wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
|
378 |
-
# print(wav2vec_embeds)
|
379 |
-
# print("audio.shape:",input_audios.shape)
|
380 |
-
wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer]
|
381 |
-
# print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape)
|
382 |
-
wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
|
383 |
-
return wav2vec_embeds_last
|
384 |
-
|
385 |
-
def extract_mert_embeds(self, input_audios):
|
386 |
-
prompt_stride = 3
|
387 |
-
inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
|
388 |
-
input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
|
389 |
-
prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
|
390 |
-
mert_emb= prompt_embeds[-1]
|
391 |
-
mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1)
|
392 |
-
|
393 |
-
return mert_emb
|
394 |
-
|
395 |
-
def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer):
|
396 |
-
self.bestrq.eval()
|
397 |
-
# print("audio shape:",input_audio_0.shape)
|
398 |
-
input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
|
399 |
-
# print("input_wav_mean.shape:",input_wav_mean.shape)
|
400 |
-
# input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device)
|
401 |
-
input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
|
402 |
-
layer_results = input_wav_mean['layer_results']
|
403 |
-
# print("layer_results.shape:",layer_results[layer].shape)
|
404 |
-
bestrq_emb = layer_results[layer]
|
405 |
-
bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
|
406 |
-
#[b,t,1024] t=t/960
|
407 |
-
#35.84s->batch,896,1024
|
408 |
-
return bestrq_emb
|
409 |
-
|
410 |
-
|
411 |
-
def extract_spk_embeds(self, input_audios):
|
412 |
-
spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
|
413 |
-
spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
|
414 |
-
return spk_embeds
|
415 |
-
|
416 |
-
def extract_lyric_feats(self, lyric):
|
417 |
-
with torch.no_grad():
|
418 |
-
try:
|
419 |
-
text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
|
420 |
-
except:
|
421 |
-
text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
|
422 |
-
text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
|
423 |
-
text_mask = text_mask.to(self.device)
|
424 |
-
text_encoder_hidden_states, text_mask, text_prompt_embeds = \
|
425 |
-
pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
|
426 |
-
text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
|
427 |
-
return text_encoder_hidden_states, text_mask
|
428 |
-
|
429 |
-
def extract_energy_bar(self, input_audios):
|
430 |
-
if(input_audios.shape[-1] % self.num_samples_perseg > 0):
|
431 |
-
energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
|
432 |
-
else:
|
433 |
-
energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
|
434 |
-
energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
|
435 |
-
energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
|
436 |
-
energy_embedding = self.energy_embedding(energy_bar)
|
437 |
-
energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
|
438 |
-
return energy_embedding
|
439 |
-
|
440 |
-
def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \
|
441 |
-
additional_feats = ['spk', 'lyric'], \
|
442 |
-
train_rvq=True, train_ssl=False,layer=5):
|
443 |
-
if not hasattr(self,"device"):
|
444 |
-
self.device = input_audios.device
|
445 |
-
if not hasattr(self,"dtype"):
|
446 |
-
self.dtype = input_audios.dtype
|
447 |
-
device = self.device
|
448 |
-
input_audio_0 = input_audios[:,0,:]
|
449 |
-
input_audio_1 = input_audios[:,1,:]
|
450 |
-
input_audio_0 = self.preprocess_audio(input_audio_0)
|
451 |
-
input_audio_1 = self.preprocess_audio(input_audio_1)
|
452 |
-
input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0
|
453 |
-
# energy_embedding = self.extract_energy_bar(input_audios)
|
454 |
-
# print("energy_embedding.shape:",energy_embedding.shape)
|
455 |
-
# with autocast(enabled=False):
|
456 |
-
if(train_ssl):
|
457 |
-
self.wav2vec.train()
|
458 |
-
wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
|
459 |
-
self.clap_embd_extractor.train()
|
460 |
-
prompt_embeds = self.extract_mert_embeds(input_audios)
|
461 |
-
if('spk' in additional_feats):
|
462 |
-
self.xvecmodel.train()
|
463 |
-
spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
|
464 |
-
else:
|
465 |
-
with torch.no_grad():
|
466 |
-
with autocast(enabled=False):
|
467 |
-
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
468 |
-
# mert_emb = self.extract_mert_embeds(input_audios_mert)
|
469 |
-
|
470 |
-
wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2])
|
471 |
-
|
472 |
-
bestrq_emb = bestrq_emb.detach()
|
473 |
-
if('lyric' in additional_feats):
|
474 |
-
text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
|
475 |
-
else:
|
476 |
-
text_encoder_hidden_states, text_mask = None, None
|
477 |
-
|
478 |
-
|
479 |
-
if(train_rvq):
|
480 |
-
random_num=random.random()
|
481 |
-
if(random_num<0.6):
|
482 |
-
rvq_layer = 1
|
483 |
-
elif(random_num<0.8):
|
484 |
-
rvq_layer = 2
|
485 |
-
else:
|
486 |
-
rvq_layer = 4
|
487 |
-
quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t
|
488 |
-
else:
|
489 |
-
bestrq_emb = bestrq_emb.float()
|
490 |
-
self.rvq_bestrq_emb.eval()
|
491 |
-
# with autocast(enabled=False):
|
492 |
-
quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
493 |
-
commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
|
494 |
-
codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
|
495 |
-
quantized_bestrq_emb = quantized_bestrq_emb.detach()
|
496 |
-
|
497 |
-
commitment_loss = commitment_loss_bestrq_emb
|
498 |
-
codebook_loss = codebook_loss_bestrq_emb
|
499 |
-
|
500 |
-
|
501 |
-
alpha=1
|
502 |
-
quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
|
503 |
-
|
504 |
-
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
505 |
-
# print("latent_masks.shape:",latent_masks.shape)
|
506 |
-
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
scenario = np.random.choice(['start_seg', 'other_seg'])
|
511 |
-
if(scenario == 'other_seg'):
|
512 |
-
for binx in range(input_audios.shape[0]):
|
513 |
-
# latent_masks[binx,0:64] = 1
|
514 |
-
latent_masks[binx,0:random.randint(64,128)] = 1
|
515 |
-
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
516 |
-
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
517 |
-
# print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape)
|
518 |
-
# print("latent_masks.shape:",latent_masks.shape)
|
519 |
-
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
520 |
-
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
if self.uncondition:
|
526 |
-
mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
|
527 |
-
if len(mask_indices) > 0:
|
528 |
-
quantized_bestrq_emb[mask_indices] = 0
|
529 |
-
# print("latents.shape:",latents.shape)
|
530 |
-
latents = latents.permute(0,2,1).contiguous()
|
531 |
-
latents = self.normfeat.project_sample(latents)
|
532 |
-
latents = latents.permute(0,2,1).contiguous()
|
533 |
-
incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
534 |
-
attention_mask=(latent_masks > 0.5)
|
535 |
-
B, L = attention_mask.size()
|
536 |
-
attention_mask = attention_mask.view(B, 1, L)
|
537 |
-
attention_mask = attention_mask * attention_mask.transpose(-1, -2)
|
538 |
-
attention_mask = attention_mask.unsqueeze(1)
|
539 |
-
# print("incontext_latents.shape:",incontext_latents.shape)
|
540 |
-
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
541 |
-
latent_mask_input = self.mask_emb(latent_masks)
|
542 |
-
#64+48+64+1024
|
543 |
-
loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
|
544 |
-
return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
|
545 |
-
|
546 |
-
def init_device_dtype(self, device, dtype):
|
547 |
-
self.device = device
|
548 |
-
self.dtype = dtype
|
549 |
-
|
550 |
-
@torch.no_grad()
|
551 |
-
def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1):
|
552 |
-
input_audio_0 = input_audios[[0],:]
|
553 |
-
input_audio_1 = input_audios[[1],:]
|
554 |
-
input_audio_0 = self.preprocess_audio(input_audio_0)
|
555 |
-
input_audio_1 = self.preprocess_audio(input_audio_1)
|
556 |
-
|
557 |
-
self.bestrq.eval()
|
558 |
-
|
559 |
-
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
560 |
-
# bestrq_middle = bestrq_middle.detach()
|
561 |
-
# bestrq_last = bestrq_last.detach()
|
562 |
-
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
563 |
-
bestrq_emb = bestrq_emb.detach()
|
564 |
-
|
565 |
-
# self.rvq_bestrq_middle.eval()
|
566 |
-
# quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
|
567 |
-
# self.rvq_bestrq_last.eval()
|
568 |
-
# quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
|
569 |
-
|
570 |
-
self.rvq_bestrq_emb.eval()
|
571 |
-
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
|
572 |
-
codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
|
573 |
-
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
574 |
-
# exit()
|
575 |
-
|
576 |
-
|
577 |
-
if('spk' in additional_feats):
|
578 |
-
self.xvecmodel.eval()
|
579 |
-
spk_embeds = self.extract_spk_embeds(input_audios)
|
580 |
-
else:
|
581 |
-
spk_embeds = None
|
582 |
-
|
583 |
-
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
584 |
-
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
585 |
-
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
586 |
-
return [codes_bestrq_emb], [bestrq_emb], spk_embeds
|
587 |
-
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
588 |
-
|
589 |
-
@torch.no_grad()
|
590 |
-
def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1):
|
591 |
-
input_audio_0 = input_audios[:,0,:]
|
592 |
-
input_audio_1 = input_audios[:,1,:]
|
593 |
-
input_audio_0 = self.preprocess_audio(input_audio_0)
|
594 |
-
input_audio_1 = self.preprocess_audio(input_audio_1)
|
595 |
-
|
596 |
-
self.bestrq.eval()
|
597 |
-
|
598 |
-
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
599 |
-
# bestrq_middle = bestrq_middle.detach()
|
600 |
-
# bestrq_last = bestrq_last.detach()
|
601 |
-
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
602 |
-
bestrq_emb = bestrq_emb.detach()
|
603 |
-
|
604 |
-
# self.rvq_bestrq_middle.eval()
|
605 |
-
# quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
|
606 |
-
# self.rvq_bestrq_last.eval()
|
607 |
-
# quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
|
608 |
-
|
609 |
-
self.rvq_bestrq_emb.eval()
|
610 |
-
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
|
611 |
-
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
612 |
-
codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
|
613 |
-
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
614 |
-
# exit()
|
615 |
-
|
616 |
-
|
617 |
-
if('spk' in additional_feats):
|
618 |
-
self.xvecmodel.eval()
|
619 |
-
spk_embeds = self.extract_spk_embeds(input_audios)
|
620 |
-
else:
|
621 |
-
spk_embeds = None
|
622 |
-
|
623 |
-
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
624 |
-
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
625 |
-
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
626 |
-
return [codes_bestrq_emb], [bestrq_emb], spk_embeds
|
627 |
-
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
628 |
-
|
629 |
-
@torch.no_grad()
|
630 |
-
def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250):
|
631 |
-
input_audio_0 = input_audios[:,0,:]
|
632 |
-
input_audio_1 = input_audios[:,1,:]
|
633 |
-
input_audio_0 = self.preprocess_audio(input_audio_0)
|
634 |
-
input_audio_1 = self.preprocess_audio(input_audio_1)
|
635 |
-
|
636 |
-
self.bestrq.eval()
|
637 |
-
|
638 |
-
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
639 |
-
# bestrq_middle = bestrq_middle.detach()
|
640 |
-
# bestrq_last = bestrq_last.detach()
|
641 |
-
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
642 |
-
bestrq_emb = bestrq_emb.detach()
|
643 |
-
|
644 |
-
# self.rvq_bestrq_middle.eval()
|
645 |
-
# quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
|
646 |
-
# self.rvq_bestrq_last.eval()
|
647 |
-
# quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
|
648 |
-
|
649 |
-
self.rvq_bestrq_emb.eval()
|
650 |
-
bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds)
|
651 |
-
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
|
652 |
-
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
653 |
-
codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
|
654 |
-
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
655 |
-
# exit()
|
656 |
-
|
657 |
-
|
658 |
-
if('spk' in additional_feats):
|
659 |
-
self.xvecmodel.eval()
|
660 |
-
spk_embeds = self.extract_spk_embeds(input_audios)
|
661 |
-
else:
|
662 |
-
spk_embeds = None
|
663 |
-
|
664 |
-
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
665 |
-
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
666 |
-
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
667 |
-
return [codes_bestrq_emb], [bestrq_emb], spk_embeds
|
668 |
-
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
669 |
-
|
670 |
-
@torch.no_grad()
|
671 |
-
def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127,
|
672 |
-
guidance_scale=2, num_steps=20,
|
673 |
-
disable_progress=True, scenario='start_seg'):
|
674 |
-
classifier_free_guidance = guidance_scale > 1.0
|
675 |
-
device = self.device
|
676 |
-
dtype = self.dtype
|
677 |
-
# codes_bestrq_middle, codes_bestrq_last = codes
|
678 |
-
codes_bestrq_emb = codes[0]
|
679 |
-
|
680 |
-
|
681 |
-
batch_size = codes_bestrq_emb.shape[0]
|
682 |
-
|
683 |
-
|
684 |
-
quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
|
685 |
-
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
686 |
-
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
687 |
-
print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
688 |
-
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
if('spk' in additional_feats):
|
694 |
-
spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
|
695 |
-
|
696 |
-
num_frames = quantized_bestrq_emb.shape[1]
|
697 |
-
|
698 |
-
num_channels_latents = self.num_channels
|
699 |
-
shape = (batch_size, num_frames, 64)
|
700 |
-
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
|
705 |
-
latent_masks[:,0:latent_length] = 2
|
706 |
-
if(scenario=='other_seg'):
|
707 |
-
latent_masks[:,0:incontext_length] = 1
|
708 |
-
|
709 |
-
|
710 |
-
|
711 |
-
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
712 |
-
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
713 |
-
true_latents = true_latents.permute(0,2,1).contiguous()
|
714 |
-
true_latents = self.normfeat.project_sample(true_latents)
|
715 |
-
true_latents = true_latents.permute(0,2,1).contiguous()
|
716 |
-
incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
717 |
-
incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
|
718 |
-
|
719 |
-
|
720 |
-
attention_mask=(latent_masks > 0.5)
|
721 |
-
B, L = attention_mask.size()
|
722 |
-
attention_mask = attention_mask.view(B, 1, L)
|
723 |
-
attention_mask = attention_mask * attention_mask.transpose(-1, -2)
|
724 |
-
attention_mask = attention_mask.unsqueeze(1)
|
725 |
-
latent_mask_input = self.mask_emb(latent_masks)
|
726 |
-
|
727 |
-
if('spk' in additional_feats):
|
728 |
-
# additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
|
729 |
-
additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1)
|
730 |
-
else:
|
731 |
-
# additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
|
732 |
-
additional_model_input = torch.cat([quantized_bestrq_emb],1)
|
733 |
-
|
734 |
-
temperature = 1.0
|
735 |
-
t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
|
736 |
-
latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
|
737 |
-
|
738 |
-
latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
|
739 |
-
latents = latents.permute(0,2,1).contiguous()
|
740 |
-
latents = self.normfeat.return_sample(latents)
|
741 |
-
# latents = latents.permute(0,2,1).contiguous()
|
742 |
-
return latents
|
743 |
-
|
744 |
-
@torch.no_grad()
|
745 |
-
def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
|
746 |
-
disable_progress=True,layer=5,scenario='start_seg',rvq_num=1):
|
747 |
-
codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num)
|
748 |
-
|
749 |
-
latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
|
750 |
-
guidance_scale=guidance_scale, num_steps=num_steps, \
|
751 |
-
disable_progress=disable_progress,scenario=scenario)
|
752 |
-
return latents
|
753 |
-
|
754 |
-
@torch.no_grad()
|
755 |
-
def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
|
756 |
-
disable_progress=True,layer=5,scenario='start_seg'):
|
757 |
-
codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
|
758 |
-
import time
|
759 |
-
start = time.time()
|
760 |
-
latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
|
761 |
-
guidance_scale=guidance_scale, num_steps=num_steps, \
|
762 |
-
disable_progress=disable_progress,scenario=scenario)
|
763 |
-
return latents,time.time()-start
|
764 |
-
|
765 |
-
def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
|
766 |
-
divisor = 4
|
767 |
-
shape = (batch_size, num_channels_latents, num_frames, 32)
|
768 |
-
if(num_frames%divisor>0):
|
769 |
-
num_frames = round(num_frames/float(divisor))*divisor
|
770 |
-
shape = (batch_size, num_channels_latents, num_frames, 32)
|
771 |
-
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
772 |
-
return latents
|
773 |
-
|
774 |
-
|
|
|
1 |
+
import yaml
|
2 |
+
import random
|
3 |
+
import inspect
|
4 |
+
import numpy as np
|
5 |
+
from tqdm import tqdm
|
6 |
+
import typing as tp
|
7 |
+
from abc import ABC
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import torchaudio
|
13 |
+
|
14 |
+
from einops import repeat
|
15 |
+
from tools.torch_tools import wav_to_fbank
|
16 |
+
|
17 |
+
import diffusers
|
18 |
+
from diffusers.utils.torch_utils import randn_tensor
|
19 |
+
from diffusers import DDPMScheduler
|
20 |
+
from models.transformer_2d_flow import Transformer2DModel
|
21 |
+
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel
|
22 |
+
# from tools.get_mulan import get_mulan
|
23 |
+
from third_party.wespeaker.extract_embd import XVECModel
|
24 |
+
# from libs.rvq2 import RVQEmbedding
|
25 |
+
from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize
|
26 |
+
|
27 |
+
from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
|
28 |
+
from models_gpt.models.gpt2_config import GPT2Config
|
29 |
+
|
30 |
+
from torch.cuda.amp import autocast
|
31 |
+
|
32 |
+
|
33 |
+
from our_MERT_BESTRQ.test import load_model
|
34 |
+
|
35 |
+
class HubertModelWithFinalProj(HubertModel):
|
36 |
+
def __init__(self, config):
|
37 |
+
super().__init__(config)
|
38 |
+
|
39 |
+
# The final projection layer is only used for backward compatibility.
|
40 |
+
# Following https://github.com/auspicious3000/contentvec/issues/6
|
41 |
+
# Remove this layer is necessary to achieve the desired outcome.
|
42 |
+
print("hidden_size:",config.hidden_size)
|
43 |
+
print("classifier_proj_size:",config.classifier_proj_size)
|
44 |
+
self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
45 |
+
|
46 |
+
|
47 |
+
class SampleProcessor(torch.nn.Module):
|
48 |
+
def project_sample(self, x: torch.Tensor):
|
49 |
+
"""Project the original sample to the 'space' where the diffusion will happen."""
|
50 |
+
"""Project back from diffusion space to the actual sample space."""
|
51 |
+
return z
|
52 |
+
|
53 |
+
class Feature1DProcessor(SampleProcessor):
|
54 |
+
def __init__(self, dim: int = 100, power_std = 1., \
|
55 |
+
num_samples: int = 100_000, cal_num_frames: int = 600):
|
56 |
+
super().__init__()
|
57 |
+
|
58 |
+
self.num_samples = num_samples
|
59 |
+
self.dim = dim
|
60 |
+
self.power_std = power_std
|
61 |
+
self.cal_num_frames = cal_num_frames
|
62 |
+
self.register_buffer('counts', torch.zeros(1))
|
63 |
+
self.register_buffer('sum_x', torch.zeros(dim))
|
64 |
+
self.register_buffer('sum_x2', torch.zeros(dim))
|
65 |
+
self.register_buffer('sum_target_x2', torch.zeros(dim))
|
66 |
+
self.counts: torch.Tensor
|
67 |
+
self.sum_x: torch.Tensor
|
68 |
+
self.sum_x2: torch.Tensor
|
69 |
+
|
70 |
+
@property
|
71 |
+
def mean(self):
|
72 |
+
mean = self.sum_x / self.counts
|
73 |
+
if(self.counts < 10):
|
74 |
+
mean = torch.zeros_like(mean)
|
75 |
+
return mean
|
76 |
+
|
77 |
+
@property
|
78 |
+
def std(self):
|
79 |
+
std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
|
80 |
+
if(self.counts < 10):
|
81 |
+
std = torch.ones_like(std)
|
82 |
+
return std
|
83 |
+
|
84 |
+
@property
|
85 |
+
def target_std(self):
|
86 |
+
return 1
|
87 |
+
|
88 |
+
def project_sample(self, x: torch.Tensor):
|
89 |
+
assert x.dim() == 3
|
90 |
+
if self.counts.item() < self.num_samples:
|
91 |
+
self.counts += len(x)
|
92 |
+
self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
|
93 |
+
self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
|
94 |
+
rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
|
95 |
+
x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
|
96 |
+
return x
|
97 |
+
|
98 |
+
def return_sample(self, x: torch.Tensor):
|
99 |
+
assert x.dim() == 3
|
100 |
+
rescale = (self.std / self.target_std) ** self.power_std
|
101 |
+
# print(rescale, self.mean)
|
102 |
+
x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
|
103 |
+
return x
|
104 |
+
|
105 |
+
def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
|
106 |
+
if(prior_text_encoder_hidden_states.shape[1]<len_size):
|
107 |
+
prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
|
108 |
+
torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
|
109 |
+
prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
|
110 |
+
dtype=prior_text_encoder_hidden_states.dtype)],1)
|
111 |
+
prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
|
112 |
+
else:
|
113 |
+
prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
|
114 |
+
prior_text_mask = prior_text_mask[:,0:len_size]
|
115 |
+
prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
|
116 |
+
return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
|
117 |
+
|
118 |
+
class BASECFM(torch.nn.Module, ABC):
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
estimator,
|
122 |
+
mlp,
|
123 |
+
ssl_layer
|
124 |
+
):
|
125 |
+
super().__init__()
|
126 |
+
self.sigma_min = 1e-4
|
127 |
+
|
128 |
+
self.estimator = estimator
|
129 |
+
self.mlp = mlp
|
130 |
+
self.ssl_layer = ssl_layer
|
131 |
+
|
132 |
+
@torch.inference_mode()
|
133 |
+
def forward(self, mu, n_timesteps, temperature=1.0):
|
134 |
+
"""Forward diffusion
|
135 |
+
|
136 |
+
Args:
|
137 |
+
mu (torch.Tensor): output of encoder
|
138 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
139 |
+
n_timesteps (int): number of diffusion steps
|
140 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
sample: generated mel-spectrogram
|
144 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
145 |
+
"""
|
146 |
+
z = torch.randn_like(mu) * temperature
|
147 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
148 |
+
return self.solve_euler(z, t_span=t_span)
|
149 |
+
|
150 |
+
def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
|
151 |
+
"""
|
152 |
+
Fixed euler solver for ODEs.
|
153 |
+
Args:
|
154 |
+
x (torch.Tensor): random noise
|
155 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
156 |
+
shape: (n_timesteps + 1,)
|
157 |
+
mu (torch.Tensor): output of encoder
|
158 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
159 |
+
"""
|
160 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
161 |
+
noise = x.clone()
|
162 |
+
|
163 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
164 |
+
# Or in future might add like a return_all_steps flag
|
165 |
+
sol = []
|
166 |
+
|
167 |
+
for step in tqdm(range(1, len(t_span))):
|
168 |
+
print("incontext_x.shape:",incontext_x.shape)
|
169 |
+
print("noise.shape:",noise.shape)
|
170 |
+
print("t.shape:",t.shape)
|
171 |
+
x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
|
172 |
+
if(guidance_scale > 1.0):
|
173 |
+
|
174 |
+
model_input = torch.cat([ \
|
175 |
+
torch.cat([latent_mask_input, latent_mask_input], 0), \
|
176 |
+
torch.cat([incontext_x, incontext_x], 0), \
|
177 |
+
torch.cat([torch.zeros_like(mu), mu], 0), \
|
178 |
+
torch.cat([x, x], 0), \
|
179 |
+
], 2)
|
180 |
+
timestep=t.unsqueeze(-1).repeat(2)
|
181 |
+
|
182 |
+
dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
|
183 |
+
dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0)
|
184 |
+
dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond)
|
185 |
+
else:
|
186 |
+
model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
|
187 |
+
timestep=t.unsqueeze(-1)
|
188 |
+
dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
|
189 |
+
|
190 |
+
dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
|
191 |
+
print("dphi_dt.shape:",dphi_dt.shape)
|
192 |
+
print("x.shape:",x.shape)
|
193 |
+
|
194 |
+
x = x + dt * dphi_dt
|
195 |
+
t = t + dt
|
196 |
+
sol.append(x)
|
197 |
+
if step < len(t_span) - 1:
|
198 |
+
dt = t_span[step + 1] - t
|
199 |
+
|
200 |
+
return sol[-1]
|
201 |
+
|
202 |
+
def projection_loss(self,hidden_proj, bestrq_emb):
|
203 |
+
bsz = hidden_proj.shape[0]
|
204 |
+
|
205 |
+
hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
|
206 |
+
bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
|
207 |
+
|
208 |
+
proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
|
209 |
+
proj_loss = 1+proj_loss.mean()
|
210 |
+
|
211 |
+
return proj_loss
|
212 |
+
|
213 |
+
def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
|
214 |
+
"""Computes diffusion loss
|
215 |
+
|
216 |
+
Args:
|
217 |
+
x1 (torch.Tensor): Target
|
218 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
219 |
+
mu (torch.Tensor): output of encoder
|
220 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
loss: conditional flow matching loss
|
224 |
+
y: conditional flow
|
225 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
226 |
+
"""
|
227 |
+
b = mu[0].shape[0]
|
228 |
+
len_x = x1.shape[2]
|
229 |
+
# random timestep
|
230 |
+
if(validation_mode):
|
231 |
+
t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
|
232 |
+
else:
|
233 |
+
t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
|
234 |
+
# sample noise p(x_0)
|
235 |
+
z = torch.randn_like(x1)
|
236 |
+
|
237 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
238 |
+
u = x1 - (1 - self.sigma_min) * z
|
239 |
+
# print("y.shape:",y.shape)
|
240 |
+
#self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state
|
241 |
+
model_input = torch.cat([*mu,y], 2)
|
242 |
+
t=t.squeeze(-1).squeeze(-1)
|
243 |
+
# print("model_input.shape:",model_input.shape)
|
244 |
+
# print("attention_mask.shape:",attention_mask.shape)
|
245 |
+
out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
|
246 |
+
hidden_layer = out.hidden_states[self.ssl_layer]
|
247 |
+
hidden_proj = self.mlp(hidden_layer)
|
248 |
+
# print("hidden_proj.shape:",hidden_proj.shape)
|
249 |
+
# print("mert_emb.shape:",mert_emb.shape)
|
250 |
+
# exit()
|
251 |
+
|
252 |
+
|
253 |
+
out = out.last_hidden_state
|
254 |
+
|
255 |
+
out=out[:,:,-len_x:]
|
256 |
+
# out=self.proj_out(out)
|
257 |
+
|
258 |
+
weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
|
259 |
+
# print("out.shape",out.shape)
|
260 |
+
# print("u.shape",u.shape)
|
261 |
+
loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
|
262 |
+
# print("hidden_proj.shape:",hidden_proj.shape)
|
263 |
+
# print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
|
264 |
+
loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
|
265 |
+
loss = loss_re + loss_cos * 0.5
|
266 |
+
# print("loss_cos:",loss_cos,loss_cos.device)
|
267 |
+
print("loss:",loss,loss.device)
|
268 |
+
# exit()
|
269 |
+
return loss, loss_re, loss_cos
|
270 |
+
|
271 |
+
class PromptCondAudioDiffusion(nn.Module):
|
272 |
+
def __init__(
|
273 |
+
self,
|
274 |
+
num_channels,
|
275 |
+
unet_model_name=None,
|
276 |
+
unet_model_config_path=None,
|
277 |
+
snr_gamma=None,
|
278 |
+
hubert_layer=None,
|
279 |
+
ssl_layer=None,
|
280 |
+
uncondition=True,
|
281 |
+
out_paint=False,
|
282 |
+
):
|
283 |
+
super().__init__()
|
284 |
+
|
285 |
+
assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
|
286 |
+
|
287 |
+
self.unet_model_name = unet_model_name
|
288 |
+
self.unet_model_config_path = unet_model_config_path
|
289 |
+
self.snr_gamma = snr_gamma
|
290 |
+
self.uncondition = uncondition
|
291 |
+
self.num_channels = num_channels
|
292 |
+
self.hubert_layer = hubert_layer
|
293 |
+
self.ssl_layer = ssl_layer
|
294 |
+
|
295 |
+
# https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
|
296 |
+
self.normfeat = Feature1DProcessor(dim=64)
|
297 |
+
|
298 |
+
self.sample_rate = 48000
|
299 |
+
self.num_samples_perseg = self.sample_rate * 20 // 1000
|
300 |
+
self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
|
301 |
+
self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
|
302 |
+
# self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
|
303 |
+
# self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
|
304 |
+
self.bestrq = load_model(
|
305 |
+
model_dir='path/to/our-MERT/mert_fairseq',
|
306 |
+
checkpoint_dir='checkpoint-120000.pt',
|
307 |
+
)
|
308 |
+
self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
|
309 |
+
self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
|
310 |
+
for v in self.bestrq.parameters():v.requires_grad = False
|
311 |
+
self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
|
312 |
+
# for v in self.rvq_bestrq_emb.parameters():
|
313 |
+
# print(v)
|
314 |
+
freeze_parameters='quantizers.0'
|
315 |
+
for name, param in self.rvq_bestrq_emb.named_parameters():
|
316 |
+
if freeze_parameters in name:
|
317 |
+
param.requires_grad = False
|
318 |
+
print("Freezing RVQ parameters:", name)
|
319 |
+
self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
|
320 |
+
for v in self.hubert.parameters():v.requires_grad = False
|
321 |
+
self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
|
322 |
+
# self.xvecmodel = XVECModel()
|
323 |
+
config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
|
324 |
+
unet = GPT2Model(config)
|
325 |
+
mlp = nn.Sequential(
|
326 |
+
nn.Linear(1200, 1024),
|
327 |
+
nn.SiLU(),
|
328 |
+
nn.Linear(1024, 1024),
|
329 |
+
nn.SiLU(),
|
330 |
+
nn.Linear(1024, 768)
|
331 |
+
)
|
332 |
+
self.set_from = "random"
|
333 |
+
self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
|
334 |
+
self.mask_emb = torch.nn.Embedding(3, 48)
|
335 |
+
print("Transformer initialized from pretrain.")
|
336 |
+
torch.cuda.empty_cache()
|
337 |
+
# self.unet.set_attn_processor(AttnProcessor2_0())
|
338 |
+
# self.unet.set_use_memory_efficient_attention_xformers(True)
|
339 |
+
|
340 |
+
# self.start_embedding = nn.Parameter(torch.randn(1,1024))
|
341 |
+
# self.end_embedding = nn.Parameter(torch.randn(1,1024))
|
342 |
+
|
343 |
+
def compute_snr(self, timesteps):
|
344 |
+
"""
|
345 |
+
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
346 |
+
"""
|
347 |
+
alphas_cumprod = self.noise_scheduler.alphas_cumprod
|
348 |
+
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
349 |
+
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
350 |
+
|
351 |
+
# Expand the tensors.
|
352 |
+
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
353 |
+
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
354 |
+
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
355 |
+
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
356 |
+
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
357 |
+
|
358 |
+
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
359 |
+
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
360 |
+
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
361 |
+
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
362 |
+
|
363 |
+
# Compute SNR.
|
364 |
+
snr = (alpha / sigma) ** 2
|
365 |
+
return snr
|
366 |
+
|
367 |
+
def preprocess_audio(self, input_audios, threshold=0.9):
|
368 |
+
assert len(input_audios.shape) == 2, input_audios.shape
|
369 |
+
norm_value = torch.ones_like(input_audios[:,0])
|
370 |
+
max_volume = input_audios.abs().max(dim=-1)[0]
|
371 |
+
norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
|
372 |
+
return input_audios/norm_value.unsqueeze(-1)
|
373 |
+
|
374 |
+
def extract_wav2vec_embeds(self, input_audios,output_len):
|
375 |
+
wav2vec_stride = 2
|
376 |
+
|
377 |
+
wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
|
378 |
+
# print(wav2vec_embeds)
|
379 |
+
# print("audio.shape:",input_audios.shape)
|
380 |
+
wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer]
|
381 |
+
# print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape)
|
382 |
+
wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
|
383 |
+
return wav2vec_embeds_last
|
384 |
+
|
385 |
+
def extract_mert_embeds(self, input_audios):
|
386 |
+
prompt_stride = 3
|
387 |
+
inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
|
388 |
+
input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
|
389 |
+
prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
|
390 |
+
mert_emb= prompt_embeds[-1]
|
391 |
+
mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1)
|
392 |
+
|
393 |
+
return mert_emb
|
394 |
+
|
395 |
+
def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer):
|
396 |
+
self.bestrq.eval()
|
397 |
+
# print("audio shape:",input_audio_0.shape)
|
398 |
+
input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
|
399 |
+
# print("input_wav_mean.shape:",input_wav_mean.shape)
|
400 |
+
# input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device)
|
401 |
+
input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
|
402 |
+
layer_results = input_wav_mean['layer_results']
|
403 |
+
# print("layer_results.shape:",layer_results[layer].shape)
|
404 |
+
bestrq_emb = layer_results[layer]
|
405 |
+
bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
|
406 |
+
#[b,t,1024] t=t/960
|
407 |
+
#35.84s->batch,896,1024
|
408 |
+
return bestrq_emb
|
409 |
+
|
410 |
+
|
411 |
+
def extract_spk_embeds(self, input_audios):
|
412 |
+
spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
|
413 |
+
spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
|
414 |
+
return spk_embeds
|
415 |
+
|
416 |
+
def extract_lyric_feats(self, lyric):
|
417 |
+
with torch.no_grad():
|
418 |
+
try:
|
419 |
+
text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
|
420 |
+
except:
|
421 |
+
text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
|
422 |
+
text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
|
423 |
+
text_mask = text_mask.to(self.device)
|
424 |
+
text_encoder_hidden_states, text_mask, text_prompt_embeds = \
|
425 |
+
pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
|
426 |
+
text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
|
427 |
+
return text_encoder_hidden_states, text_mask
|
428 |
+
|
429 |
+
def extract_energy_bar(self, input_audios):
|
430 |
+
if(input_audios.shape[-1] % self.num_samples_perseg > 0):
|
431 |
+
energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
|
432 |
+
else:
|
433 |
+
energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
|
434 |
+
energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
|
435 |
+
energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
|
436 |
+
energy_embedding = self.energy_embedding(energy_bar)
|
437 |
+
energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
|
438 |
+
return energy_embedding
|
439 |
+
|
440 |
+
def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \
|
441 |
+
additional_feats = ['spk', 'lyric'], \
|
442 |
+
train_rvq=True, train_ssl=False,layer=5):
|
443 |
+
if not hasattr(self,"device"):
|
444 |
+
self.device = input_audios.device
|
445 |
+
if not hasattr(self,"dtype"):
|
446 |
+
self.dtype = input_audios.dtype
|
447 |
+
device = self.device
|
448 |
+
input_audio_0 = input_audios[:,0,:]
|
449 |
+
input_audio_1 = input_audios[:,1,:]
|
450 |
+
input_audio_0 = self.preprocess_audio(input_audio_0)
|
451 |
+
input_audio_1 = self.preprocess_audio(input_audio_1)
|
452 |
+
input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0
|
453 |
+
# energy_embedding = self.extract_energy_bar(input_audios)
|
454 |
+
# print("energy_embedding.shape:",energy_embedding.shape)
|
455 |
+
# with autocast(enabled=False):
|
456 |
+
if(train_ssl):
|
457 |
+
self.wav2vec.train()
|
458 |
+
wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
|
459 |
+
self.clap_embd_extractor.train()
|
460 |
+
prompt_embeds = self.extract_mert_embeds(input_audios)
|
461 |
+
if('spk' in additional_feats):
|
462 |
+
self.xvecmodel.train()
|
463 |
+
spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
|
464 |
+
else:
|
465 |
+
with torch.no_grad():
|
466 |
+
with autocast(enabled=False):
|
467 |
+
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
468 |
+
# mert_emb = self.extract_mert_embeds(input_audios_mert)
|
469 |
+
|
470 |
+
wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2])
|
471 |
+
|
472 |
+
bestrq_emb = bestrq_emb.detach()
|
473 |
+
if('lyric' in additional_feats):
|
474 |
+
text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
|
475 |
+
else:
|
476 |
+
text_encoder_hidden_states, text_mask = None, None
|
477 |
+
|
478 |
+
|
479 |
+
if(train_rvq):
|
480 |
+
random_num=random.random()
|
481 |
+
if(random_num<0.6):
|
482 |
+
rvq_layer = 1
|
483 |
+
elif(random_num<0.8):
|
484 |
+
rvq_layer = 2
|
485 |
+
else:
|
486 |
+
rvq_layer = 4
|
487 |
+
quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t
|
488 |
+
else:
|
489 |
+
bestrq_emb = bestrq_emb.float()
|
490 |
+
self.rvq_bestrq_emb.eval()
|
491 |
+
# with autocast(enabled=False):
|
492 |
+
quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
493 |
+
commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
|
494 |
+
codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
|
495 |
+
quantized_bestrq_emb = quantized_bestrq_emb.detach()
|
496 |
+
|
497 |
+
commitment_loss = commitment_loss_bestrq_emb
|
498 |
+
codebook_loss = codebook_loss_bestrq_emb
|
499 |
+
|
500 |
+
|
501 |
+
alpha=1
|
502 |
+
quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
|
503 |
+
|
504 |
+
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
505 |
+
# print("latent_masks.shape:",latent_masks.shape)
|
506 |
+
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
507 |
+
|
508 |
+
|
509 |
+
|
510 |
+
scenario = np.random.choice(['start_seg', 'other_seg'])
|
511 |
+
if(scenario == 'other_seg'):
|
512 |
+
for binx in range(input_audios.shape[0]):
|
513 |
+
# latent_masks[binx,0:64] = 1
|
514 |
+
latent_masks[binx,0:random.randint(64,128)] = 1
|
515 |
+
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
516 |
+
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
517 |
+
# print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape)
|
518 |
+
# print("latent_masks.shape:",latent_masks.shape)
|
519 |
+
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
520 |
+
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
521 |
+
|
522 |
+
|
523 |
+
|
524 |
+
|
525 |
+
if self.uncondition:
|
526 |
+
mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
|
527 |
+
if len(mask_indices) > 0:
|
528 |
+
quantized_bestrq_emb[mask_indices] = 0
|
529 |
+
# print("latents.shape:",latents.shape)
|
530 |
+
latents = latents.permute(0,2,1).contiguous()
|
531 |
+
latents = self.normfeat.project_sample(latents)
|
532 |
+
latents = latents.permute(0,2,1).contiguous()
|
533 |
+
incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
534 |
+
attention_mask=(latent_masks > 0.5)
|
535 |
+
B, L = attention_mask.size()
|
536 |
+
attention_mask = attention_mask.view(B, 1, L)
|
537 |
+
attention_mask = attention_mask * attention_mask.transpose(-1, -2)
|
538 |
+
attention_mask = attention_mask.unsqueeze(1)
|
539 |
+
# print("incontext_latents.shape:",incontext_latents.shape)
|
540 |
+
# print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
541 |
+
latent_mask_input = self.mask_emb(latent_masks)
|
542 |
+
#64+48+64+1024
|
543 |
+
loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
|
544 |
+
return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
|
545 |
+
|
546 |
+
def init_device_dtype(self, device, dtype):
|
547 |
+
self.device = device
|
548 |
+
self.dtype = dtype
|
549 |
+
|
550 |
+
@torch.no_grad()
|
551 |
+
def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1):
|
552 |
+
input_audio_0 = input_audios[[0],:]
|
553 |
+
input_audio_1 = input_audios[[1],:]
|
554 |
+
input_audio_0 = self.preprocess_audio(input_audio_0)
|
555 |
+
input_audio_1 = self.preprocess_audio(input_audio_1)
|
556 |
+
|
557 |
+
self.bestrq.eval()
|
558 |
+
|
559 |
+
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
560 |
+
# bestrq_middle = bestrq_middle.detach()
|
561 |
+
# bestrq_last = bestrq_last.detach()
|
562 |
+
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
563 |
+
bestrq_emb = bestrq_emb.detach()
|
564 |
+
|
565 |
+
# self.rvq_bestrq_middle.eval()
|
566 |
+
# quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
|
567 |
+
# self.rvq_bestrq_last.eval()
|
568 |
+
# quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
|
569 |
+
|
570 |
+
self.rvq_bestrq_emb.eval()
|
571 |
+
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
|
572 |
+
codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
|
573 |
+
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
574 |
+
# exit()
|
575 |
+
|
576 |
+
|
577 |
+
if('spk' in additional_feats):
|
578 |
+
self.xvecmodel.eval()
|
579 |
+
spk_embeds = self.extract_spk_embeds(input_audios)
|
580 |
+
else:
|
581 |
+
spk_embeds = None
|
582 |
+
|
583 |
+
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
584 |
+
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
585 |
+
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
586 |
+
return [codes_bestrq_emb], [bestrq_emb], spk_embeds
|
587 |
+
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
588 |
+
|
589 |
+
@torch.no_grad()
|
590 |
+
def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1):
|
591 |
+
input_audio_0 = input_audios[:,0,:]
|
592 |
+
input_audio_1 = input_audios[:,1,:]
|
593 |
+
input_audio_0 = self.preprocess_audio(input_audio_0)
|
594 |
+
input_audio_1 = self.preprocess_audio(input_audio_1)
|
595 |
+
|
596 |
+
self.bestrq.eval()
|
597 |
+
|
598 |
+
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
599 |
+
# bestrq_middle = bestrq_middle.detach()
|
600 |
+
# bestrq_last = bestrq_last.detach()
|
601 |
+
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
602 |
+
bestrq_emb = bestrq_emb.detach()
|
603 |
+
|
604 |
+
# self.rvq_bestrq_middle.eval()
|
605 |
+
# quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
|
606 |
+
# self.rvq_bestrq_last.eval()
|
607 |
+
# quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
|
608 |
+
|
609 |
+
self.rvq_bestrq_emb.eval()
|
610 |
+
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
|
611 |
+
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
612 |
+
codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
|
613 |
+
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
614 |
+
# exit()
|
615 |
+
|
616 |
+
|
617 |
+
if('spk' in additional_feats):
|
618 |
+
self.xvecmodel.eval()
|
619 |
+
spk_embeds = self.extract_spk_embeds(input_audios)
|
620 |
+
else:
|
621 |
+
spk_embeds = None
|
622 |
+
|
623 |
+
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
624 |
+
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
625 |
+
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
626 |
+
return [codes_bestrq_emb], [bestrq_emb], spk_embeds
|
627 |
+
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
628 |
+
|
629 |
+
@torch.no_grad()
|
630 |
+
def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250):
|
631 |
+
input_audio_0 = input_audios[:,0,:]
|
632 |
+
input_audio_1 = input_audios[:,1,:]
|
633 |
+
input_audio_0 = self.preprocess_audio(input_audio_0)
|
634 |
+
input_audio_1 = self.preprocess_audio(input_audio_1)
|
635 |
+
|
636 |
+
self.bestrq.eval()
|
637 |
+
|
638 |
+
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
639 |
+
# bestrq_middle = bestrq_middle.detach()
|
640 |
+
# bestrq_last = bestrq_last.detach()
|
641 |
+
bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
|
642 |
+
bestrq_emb = bestrq_emb.detach()
|
643 |
+
|
644 |
+
# self.rvq_bestrq_middle.eval()
|
645 |
+
# quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
|
646 |
+
# self.rvq_bestrq_last.eval()
|
647 |
+
# quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
|
648 |
+
|
649 |
+
self.rvq_bestrq_emb.eval()
|
650 |
+
bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds)
|
651 |
+
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
|
652 |
+
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
653 |
+
codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
|
654 |
+
# print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
|
655 |
+
# exit()
|
656 |
+
|
657 |
+
|
658 |
+
if('spk' in additional_feats):
|
659 |
+
self.xvecmodel.eval()
|
660 |
+
spk_embeds = self.extract_spk_embeds(input_audios)
|
661 |
+
else:
|
662 |
+
spk_embeds = None
|
663 |
+
|
664 |
+
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
665 |
+
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
666 |
+
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
667 |
+
return [codes_bestrq_emb], [bestrq_emb], spk_embeds
|
668 |
+
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
669 |
+
|
670 |
+
@torch.no_grad()
|
671 |
+
def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127,
|
672 |
+
guidance_scale=2, num_steps=20,
|
673 |
+
disable_progress=True, scenario='start_seg'):
|
674 |
+
classifier_free_guidance = guidance_scale > 1.0
|
675 |
+
device = self.device
|
676 |
+
dtype = self.dtype
|
677 |
+
# codes_bestrq_middle, codes_bestrq_last = codes
|
678 |
+
codes_bestrq_emb = codes[0]
|
679 |
+
|
680 |
+
|
681 |
+
batch_size = codes_bestrq_emb.shape[0]
|
682 |
+
|
683 |
+
|
684 |
+
quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
|
685 |
+
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
686 |
+
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
687 |
+
print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
688 |
+
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
689 |
+
|
690 |
+
|
691 |
+
|
692 |
+
|
693 |
+
if('spk' in additional_feats):
|
694 |
+
spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
|
695 |
+
|
696 |
+
num_frames = quantized_bestrq_emb.shape[1]
|
697 |
+
|
698 |
+
num_channels_latents = self.num_channels
|
699 |
+
shape = (batch_size, num_frames, 64)
|
700 |
+
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
701 |
+
|
702 |
+
|
703 |
+
|
704 |
+
latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
|
705 |
+
latent_masks[:,0:latent_length] = 2
|
706 |
+
if(scenario=='other_seg'):
|
707 |
+
latent_masks[:,0:incontext_length] = 1
|
708 |
+
|
709 |
+
|
710 |
+
|
711 |
+
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
712 |
+
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
713 |
+
true_latents = true_latents.permute(0,2,1).contiguous()
|
714 |
+
true_latents = self.normfeat.project_sample(true_latents)
|
715 |
+
true_latents = true_latents.permute(0,2,1).contiguous()
|
716 |
+
incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
717 |
+
incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
|
718 |
+
|
719 |
+
|
720 |
+
attention_mask=(latent_masks > 0.5)
|
721 |
+
B, L = attention_mask.size()
|
722 |
+
attention_mask = attention_mask.view(B, 1, L)
|
723 |
+
attention_mask = attention_mask * attention_mask.transpose(-1, -2)
|
724 |
+
attention_mask = attention_mask.unsqueeze(1)
|
725 |
+
latent_mask_input = self.mask_emb(latent_masks)
|
726 |
+
|
727 |
+
if('spk' in additional_feats):
|
728 |
+
# additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
|
729 |
+
additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1)
|
730 |
+
else:
|
731 |
+
# additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
|
732 |
+
additional_model_input = torch.cat([quantized_bestrq_emb],1)
|
733 |
+
|
734 |
+
temperature = 1.0
|
735 |
+
t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
|
736 |
+
latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
|
737 |
+
|
738 |
+
latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
|
739 |
+
latents = latents.permute(0,2,1).contiguous()
|
740 |
+
latents = self.normfeat.return_sample(latents)
|
741 |
+
# latents = latents.permute(0,2,1).contiguous()
|
742 |
+
return latents
|
743 |
+
|
744 |
+
@torch.no_grad()
|
745 |
+
def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
|
746 |
+
disable_progress=True,layer=5,scenario='start_seg',rvq_num=1):
|
747 |
+
codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num)
|
748 |
+
|
749 |
+
latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
|
750 |
+
guidance_scale=guidance_scale, num_steps=num_steps, \
|
751 |
+
disable_progress=disable_progress,scenario=scenario)
|
752 |
+
return latents
|
753 |
+
|
754 |
+
@torch.no_grad()
|
755 |
+
def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
|
756 |
+
disable_progress=True,layer=5,scenario='start_seg'):
|
757 |
+
codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
|
758 |
+
import time
|
759 |
+
start = time.time()
|
760 |
+
latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
|
761 |
+
guidance_scale=guidance_scale, num_steps=num_steps, \
|
762 |
+
disable_progress=disable_progress,scenario=scenario)
|
763 |
+
return latents,time.time()-start
|
764 |
+
|
765 |
+
def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
|
766 |
+
divisor = 4
|
767 |
+
shape = (batch_size, num_channels_latents, num_frames, 32)
|
768 |
+
if(num_frames%divisor>0):
|
769 |
+
num_frames = round(num_frames/float(divisor))*divisor
|
770 |
+
shape = (batch_size, num_channels_latents, num_frames, 32)
|
771 |
+
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
772 |
+
return latents
|
773 |
+
|
774 |
+
|
codeclm/tokenizer/Flow1dVAE/model_septoken.py
CHANGED
@@ -1,670 +1,670 @@
|
|
1 |
-
import yaml
|
2 |
-
import random
|
3 |
-
import inspect
|
4 |
-
import numpy as np
|
5 |
-
from tqdm import tqdm
|
6 |
-
import typing as tp
|
7 |
-
from abc import ABC
|
8 |
-
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
import torch.nn.functional as F
|
12 |
-
import torchaudio
|
13 |
-
|
14 |
-
from einops import repeat
|
15 |
-
from tools.torch_tools import wav_to_fbank
|
16 |
-
|
17 |
-
from diffusers.utils.torch_utils import randn_tensor
|
18 |
-
from transformers import HubertModel
|
19 |
-
from libs.rvq.descript_quantize3 import ResidualVectorQuantize
|
20 |
-
|
21 |
-
from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
|
22 |
-
from models_gpt.models.gpt2_config import GPT2Config
|
23 |
-
|
24 |
-
from torch.cuda.amp import autocast
|
25 |
-
from our_MERT_BESTRQ.test import load_model
|
26 |
-
|
27 |
-
class HubertModelWithFinalProj(HubertModel):
|
28 |
-
def __init__(self, config):
|
29 |
-
super().__init__(config)
|
30 |
-
|
31 |
-
# The final projection layer is only used for backward compatibility.
|
32 |
-
# Following https://github.com/auspicious3000/contentvec/issues/6
|
33 |
-
# Remove this layer is necessary to achieve the desired outcome.
|
34 |
-
print("hidden_size:",config.hidden_size)
|
35 |
-
print("classifier_proj_size:",config.classifier_proj_size)
|
36 |
-
self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
37 |
-
|
38 |
-
|
39 |
-
class SampleProcessor(torch.nn.Module):
|
40 |
-
def project_sample(self, x: torch.Tensor):
|
41 |
-
"""Project the original sample to the 'space' where the diffusion will happen."""
|
42 |
-
"""Project back from diffusion space to the actual sample space."""
|
43 |
-
return z
|
44 |
-
|
45 |
-
class Feature1DProcessor(SampleProcessor):
|
46 |
-
def __init__(self, dim: int = 100, power_std = 1., \
|
47 |
-
num_samples: int = 100_000, cal_num_frames: int = 600):
|
48 |
-
super().__init__()
|
49 |
-
|
50 |
-
self.num_samples = num_samples
|
51 |
-
self.dim = dim
|
52 |
-
self.power_std = power_std
|
53 |
-
self.cal_num_frames = cal_num_frames
|
54 |
-
self.register_buffer('counts', torch.zeros(1))
|
55 |
-
self.register_buffer('sum_x', torch.zeros(dim))
|
56 |
-
self.register_buffer('sum_x2', torch.zeros(dim))
|
57 |
-
self.register_buffer('sum_target_x2', torch.zeros(dim))
|
58 |
-
self.counts: torch.Tensor
|
59 |
-
self.sum_x: torch.Tensor
|
60 |
-
self.sum_x2: torch.Tensor
|
61 |
-
|
62 |
-
@property
|
63 |
-
def mean(self):
|
64 |
-
mean = self.sum_x / self.counts
|
65 |
-
if(self.counts < 10):
|
66 |
-
mean = torch.zeros_like(mean)
|
67 |
-
return mean
|
68 |
-
|
69 |
-
@property
|
70 |
-
def std(self):
|
71 |
-
std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
|
72 |
-
if(self.counts < 10):
|
73 |
-
std = torch.ones_like(std)
|
74 |
-
return std
|
75 |
-
|
76 |
-
@property
|
77 |
-
def target_std(self):
|
78 |
-
return 1
|
79 |
-
|
80 |
-
def project_sample(self, x: torch.Tensor):
|
81 |
-
assert x.dim() == 3
|
82 |
-
if self.counts.item() < self.num_samples:
|
83 |
-
self.counts += len(x)
|
84 |
-
self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
|
85 |
-
self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
|
86 |
-
rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
|
87 |
-
x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
|
88 |
-
return x
|
89 |
-
|
90 |
-
def return_sample(self, x: torch.Tensor):
|
91 |
-
assert x.dim() == 3
|
92 |
-
rescale = (self.std / self.target_std) ** self.power_std
|
93 |
-
x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
|
94 |
-
return x
|
95 |
-
|
96 |
-
def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
|
97 |
-
if(prior_text_encoder_hidden_states.shape[1]<len_size):
|
98 |
-
prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
|
99 |
-
torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
|
100 |
-
prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
|
101 |
-
dtype=prior_text_encoder_hidden_states.dtype)],1)
|
102 |
-
prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
|
103 |
-
else:
|
104 |
-
prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
|
105 |
-
prior_text_mask = prior_text_mask[:,0:len_size]
|
106 |
-
prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
|
107 |
-
return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
|
108 |
-
|
109 |
-
class BASECFM(torch.nn.Module, ABC):
|
110 |
-
def __init__(
|
111 |
-
self,
|
112 |
-
estimator,
|
113 |
-
mlp
|
114 |
-
):
|
115 |
-
super().__init__()
|
116 |
-
self.sigma_min = 1e-4
|
117 |
-
|
118 |
-
self.estimator = estimator
|
119 |
-
self.mlp = mlp
|
120 |
-
|
121 |
-
@torch.inference_mode()
|
122 |
-
def forward(self, mu, n_timesteps, temperature=1.0):
|
123 |
-
"""Forward diffusion
|
124 |
-
|
125 |
-
Args:
|
126 |
-
mu (torch.Tensor): output of encoder
|
127 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
128 |
-
n_timesteps (int): number of diffusion steps
|
129 |
-
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
130 |
-
|
131 |
-
Returns:
|
132 |
-
sample: generated mel-spectrogram
|
133 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
134 |
-
"""
|
135 |
-
z = torch.randn_like(mu) * temperature
|
136 |
-
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
137 |
-
return self.solve_euler(z, t_span=t_span)
|
138 |
-
|
139 |
-
def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
|
140 |
-
"""
|
141 |
-
Fixed euler solver for ODEs.
|
142 |
-
Args:
|
143 |
-
x (torch.Tensor): random noise
|
144 |
-
t_span (torch.Tensor): n_timesteps interpolated
|
145 |
-
shape: (n_timesteps + 1,)
|
146 |
-
mu (torch.Tensor): output of encoder
|
147 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
148 |
-
"""
|
149 |
-
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
150 |
-
noise = x.clone()
|
151 |
-
|
152 |
-
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
153 |
-
# Or in future might add like a return_all_steps flag
|
154 |
-
sol = []
|
155 |
-
|
156 |
-
for step in tqdm(range(1, len(t_span))):
|
157 |
-
x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
|
158 |
-
if(guidance_scale > 1.0):
|
159 |
-
|
160 |
-
model_input = torch.cat([ \
|
161 |
-
torch.cat([latent_mask_input, latent_mask_input], 0), \
|
162 |
-
torch.cat([incontext_x, incontext_x], 0), \
|
163 |
-
torch.cat([torch.zeros_like(mu), mu], 0), \
|
164 |
-
torch.cat([x, x], 0), \
|
165 |
-
], 2)
|
166 |
-
timestep=t.unsqueeze(-1).repeat(2)
|
167 |
-
|
168 |
-
dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
|
169 |
-
dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0)
|
170 |
-
dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond)
|
171 |
-
else:
|
172 |
-
model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
|
173 |
-
timestep=t.unsqueeze(-1)
|
174 |
-
dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
|
175 |
-
|
176 |
-
dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
|
177 |
-
x = x + dt * dphi_dt
|
178 |
-
t = t + dt
|
179 |
-
sol.append(x)
|
180 |
-
if step < len(t_span) - 1:
|
181 |
-
dt = t_span[step + 1] - t
|
182 |
-
|
183 |
-
return sol[-1]
|
184 |
-
|
185 |
-
def projection_loss(self,hidden_proj, bestrq_emb):
|
186 |
-
bsz = hidden_proj.shape[0]
|
187 |
-
|
188 |
-
hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
|
189 |
-
bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
|
190 |
-
|
191 |
-
proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
|
192 |
-
proj_loss = 1+proj_loss.mean()
|
193 |
-
|
194 |
-
return proj_loss
|
195 |
-
|
196 |
-
def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
|
197 |
-
"""Computes diffusion loss
|
198 |
-
|
199 |
-
Args:
|
200 |
-
x1 (torch.Tensor): Target
|
201 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
202 |
-
mu (torch.Tensor): output of encoder
|
203 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
204 |
-
|
205 |
-
Returns:
|
206 |
-
loss: conditional flow matching loss
|
207 |
-
y: conditional flow
|
208 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
209 |
-
"""
|
210 |
-
b = mu[0].shape[0]
|
211 |
-
len_x = x1.shape[2]
|
212 |
-
# random timestep
|
213 |
-
if(validation_mode):
|
214 |
-
t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
|
215 |
-
else:
|
216 |
-
t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
|
217 |
-
# sample noise p(x_0)
|
218 |
-
z = torch.randn_like(x1)
|
219 |
-
|
220 |
-
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
221 |
-
u = x1 - (1 - self.sigma_min) * z
|
222 |
-
model_input = torch.cat([*mu,y], 2)
|
223 |
-
t=t.squeeze(-1).squeeze(-1)
|
224 |
-
out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
|
225 |
-
hidden_layer_7 = out.hidden_states[7]
|
226 |
-
hidden_proj = self.mlp(hidden_layer_7)
|
227 |
-
out = out.last_hidden_state
|
228 |
-
out=out[:,:,-len_x:]
|
229 |
-
|
230 |
-
weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
|
231 |
-
loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
|
232 |
-
loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
|
233 |
-
loss = loss_re + loss_cos * 0.5
|
234 |
-
return loss, loss_re, loss_cos
|
235 |
-
|
236 |
-
class PromptCondAudioDiffusion(nn.Module):
|
237 |
-
def __init__(
|
238 |
-
self,
|
239 |
-
num_channels,
|
240 |
-
unet_model_name=None,
|
241 |
-
unet_model_config_path=None,
|
242 |
-
snr_gamma=None,
|
243 |
-
uncondition=True,
|
244 |
-
out_paint=False,
|
245 |
-
):
|
246 |
-
super().__init__()
|
247 |
-
|
248 |
-
assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
|
249 |
-
|
250 |
-
self.unet_model_name = unet_model_name
|
251 |
-
self.unet_model_config_path = unet_model_config_path
|
252 |
-
self.snr_gamma = snr_gamma
|
253 |
-
self.uncondition = uncondition
|
254 |
-
self.num_channels = num_channels
|
255 |
-
|
256 |
-
# https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
|
257 |
-
self.normfeat = Feature1DProcessor(dim=64)
|
258 |
-
|
259 |
-
self.sample_rate = 48000
|
260 |
-
self.num_samples_perseg = self.sample_rate * 20 // 1000
|
261 |
-
self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
|
262 |
-
self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
|
263 |
-
# self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
|
264 |
-
# self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
|
265 |
-
self.bestrq = load_model(
|
266 |
-
model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq',
|
267 |
-
checkpoint_dir='ckpt/encode-s12k.pt',
|
268 |
-
)
|
269 |
-
self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
|
270 |
-
self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
|
271 |
-
for v in self.bestrq.parameters():v.requires_grad = False
|
272 |
-
self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
|
273 |
-
self.rvq_bestrq_bgm_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
|
274 |
-
self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
|
275 |
-
for v in self.hubert.parameters():v.requires_grad = False
|
276 |
-
self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
|
277 |
-
# self.xvecmodel = XVECModel()
|
278 |
-
config = GPT2Config(n_positions=1000,n_layer=16,n_head=20,n_embd=2200,n_inner=4400)
|
279 |
-
unet = GPT2Model(config)
|
280 |
-
mlp = nn.Sequential(
|
281 |
-
nn.Linear(2200, 1024),
|
282 |
-
nn.SiLU(),
|
283 |
-
nn.Linear(1024, 1024),
|
284 |
-
nn.SiLU(),
|
285 |
-
nn.Linear(1024, 768)
|
286 |
-
)
|
287 |
-
self.set_from = "random"
|
288 |
-
self.cfm_wrapper = BASECFM(unet, mlp)
|
289 |
-
self.mask_emb = torch.nn.Embedding(3, 24)
|
290 |
-
print("Transformer initialized from pretrain.")
|
291 |
-
torch.cuda.empty_cache()
|
292 |
-
|
293 |
-
def compute_snr(self, timesteps):
|
294 |
-
"""
|
295 |
-
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
296 |
-
"""
|
297 |
-
alphas_cumprod = self.noise_scheduler.alphas_cumprod
|
298 |
-
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
299 |
-
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
300 |
-
|
301 |
-
# Expand the tensors.
|
302 |
-
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
303 |
-
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
304 |
-
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
305 |
-
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
306 |
-
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
307 |
-
|
308 |
-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
309 |
-
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
310 |
-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
311 |
-
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
312 |
-
|
313 |
-
# Compute SNR.
|
314 |
-
snr = (alpha / sigma) ** 2
|
315 |
-
return snr
|
316 |
-
|
317 |
-
def preprocess_audio(self, input_audios, threshold=0.9):
|
318 |
-
assert len(input_audios.shape) == 2, input_audios.shape
|
319 |
-
norm_value = torch.ones_like(input_audios[:,0])
|
320 |
-
max_volume = input_audios.abs().max(dim=-1)[0]
|
321 |
-
norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
|
322 |
-
return input_audios/norm_value.unsqueeze(-1)
|
323 |
-
|
324 |
-
def extract_wav2vec_embeds(self, input_audios,output_len):
|
325 |
-
wav2vec_stride = 2
|
326 |
-
|
327 |
-
wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
|
328 |
-
wav2vec_embeds_last=wav2vec_embeds[-1]
|
329 |
-
wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
|
330 |
-
return wav2vec_embeds_last
|
331 |
-
|
332 |
-
def extract_mert_embeds(self, input_audios):
|
333 |
-
prompt_stride = 3
|
334 |
-
inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
|
335 |
-
input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
|
336 |
-
prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
|
337 |
-
mert_emb= prompt_embeds[-1]
|
338 |
-
mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=375, mode='linear', align_corners=False).permute(0, 2, 1)
|
339 |
-
|
340 |
-
return mert_emb
|
341 |
-
|
342 |
-
def extract_bestrq_embeds(self, input_audio_vocal_0,input_audio_vocal_1,layer):
|
343 |
-
input_wav_mean = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0
|
344 |
-
input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
|
345 |
-
layer_results = input_wav_mean['layer_results']
|
346 |
-
bestrq_emb = layer_results[layer]
|
347 |
-
bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
|
348 |
-
return bestrq_emb
|
349 |
-
|
350 |
-
|
351 |
-
def extract_spk_embeds(self, input_audios):
|
352 |
-
spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
|
353 |
-
spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
|
354 |
-
return spk_embeds
|
355 |
-
|
356 |
-
def extract_lyric_feats(self, lyric):
|
357 |
-
with torch.no_grad():
|
358 |
-
try:
|
359 |
-
text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
|
360 |
-
except:
|
361 |
-
text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
|
362 |
-
text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
|
363 |
-
text_mask = text_mask.to(self.device)
|
364 |
-
text_encoder_hidden_states, text_mask, text_prompt_embeds = \
|
365 |
-
pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
|
366 |
-
text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
|
367 |
-
return text_encoder_hidden_states, text_mask
|
368 |
-
|
369 |
-
def extract_energy_bar(self, input_audios):
|
370 |
-
if(input_audios.shape[-1] % self.num_samples_perseg > 0):
|
371 |
-
energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
|
372 |
-
else:
|
373 |
-
energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
|
374 |
-
energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
|
375 |
-
energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
|
376 |
-
energy_embedding = self.energy_embedding(energy_bar)
|
377 |
-
energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
|
378 |
-
return energy_embedding
|
379 |
-
|
380 |
-
def forward(self, input_audios_vocal,input_audios_bgm, lyric, latents, latent_masks, validation_mode=False, \
|
381 |
-
additional_feats = ['spk', 'lyric'], \
|
382 |
-
train_rvq=True, train_ssl=False,layer_vocal=7,layer_bgm=7):
|
383 |
-
if not hasattr(self,"device"):
|
384 |
-
self.device = input_audios_vocal.device
|
385 |
-
if not hasattr(self,"dtype"):
|
386 |
-
self.dtype = input_audios_vocal.dtype
|
387 |
-
device = self.device
|
388 |
-
input_audio_vocal_0 = input_audios_vocal[:,0,:]
|
389 |
-
input_audio_vocal_1 = input_audios_vocal[:,1,:]
|
390 |
-
input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0)
|
391 |
-
input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1)
|
392 |
-
input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0
|
393 |
-
|
394 |
-
input_audio_bgm_0 = input_audios_bgm[:,0,:]
|
395 |
-
input_audio_bgm_1 = input_audios_bgm[:,1,:]
|
396 |
-
input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0)
|
397 |
-
input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1)
|
398 |
-
input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0
|
399 |
-
|
400 |
-
if(train_ssl):
|
401 |
-
self.wav2vec.train()
|
402 |
-
wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
|
403 |
-
self.clap_embd_extractor.train()
|
404 |
-
prompt_embeds = self.extract_mert_embeds(input_audios)
|
405 |
-
if('spk' in additional_feats):
|
406 |
-
self.xvecmodel.train()
|
407 |
-
spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
|
408 |
-
else:
|
409 |
-
with torch.no_grad():
|
410 |
-
with autocast(enabled=False):
|
411 |
-
bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal)
|
412 |
-
bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm)
|
413 |
-
# mert_emb = self.extract_mert_embeds(input_audios_mert)
|
414 |
-
output_len = bestrq_emb.shape[2]
|
415 |
-
wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_vocal_wav2vec+input_audios_bgm_wav2vec,output_len)
|
416 |
-
|
417 |
-
|
418 |
-
bestrq_emb = bestrq_emb.detach()
|
419 |
-
bestrq_emb_bgm = bestrq_emb_bgm.detach()
|
420 |
-
|
421 |
-
if('lyric' in additional_feats):
|
422 |
-
text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
|
423 |
-
else:
|
424 |
-
text_encoder_hidden_states, text_mask = None, None
|
425 |
-
|
426 |
-
|
427 |
-
if(train_rvq):
|
428 |
-
quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
429 |
-
quantized_bestrq_emb_bgm, _, _, commitment_loss_bestrq_emb_bgm, codebook_loss_bestrq_emb_bgm,_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t
|
430 |
-
else:
|
431 |
-
bestrq_emb = bestrq_emb.float()
|
432 |
-
self.rvq_bestrq_emb.eval()
|
433 |
-
# with autocast(enabled=False):
|
434 |
-
quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
435 |
-
commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
|
436 |
-
codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
|
437 |
-
quantized_bestrq_emb = quantized_bestrq_emb.detach()
|
438 |
-
|
439 |
-
commitment_loss = commitment_loss_bestrq_emb+commitment_loss_bestrq_emb_bgm
|
440 |
-
codebook_loss = codebook_loss_bestrq_emb+codebook_loss_bestrq_emb_bgm
|
441 |
-
|
442 |
-
|
443 |
-
alpha=1
|
444 |
-
quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
|
445 |
-
quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm * alpha + bestrq_emb_bgm * (1-alpha)
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
scenario = np.random.choice(['start_seg', 'other_seg'])
|
451 |
-
if(scenario == 'other_seg'):
|
452 |
-
for binx in range(input_audios_vocal.shape[0]):
|
453 |
-
# latent_masks[binx,0:64] = 1
|
454 |
-
latent_masks[binx,0:random.randint(64,128)] = 1
|
455 |
-
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
456 |
-
quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm.permute(0,2,1).contiguous()
|
457 |
-
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
458 |
-
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
459 |
-
quantized_bestrq_emb_bgm = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb_bgm \
|
460 |
-
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
if self.uncondition:
|
466 |
-
mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
|
467 |
-
if len(mask_indices) > 0:
|
468 |
-
quantized_bestrq_emb[mask_indices] = 0
|
469 |
-
quantized_bestrq_emb_bgm[mask_indices] = 0
|
470 |
-
latents = latents.permute(0,2,1).contiguous()
|
471 |
-
latents = self.normfeat.project_sample(latents)
|
472 |
-
latents = latents.permute(0,2,1).contiguous()
|
473 |
-
incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
474 |
-
attention_mask=(latent_masks > 0.5)
|
475 |
-
B, L = attention_mask.size()
|
476 |
-
attention_mask = attention_mask.view(B, 1, L)
|
477 |
-
attention_mask = attention_mask * attention_mask.transpose(-1, -2)
|
478 |
-
attention_mask = attention_mask.unsqueeze(1)
|
479 |
-
latent_mask_input = self.mask_emb(latent_masks)
|
480 |
-
loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb,quantized_bestrq_emb_bgm], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
|
481 |
-
return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
|
482 |
-
|
483 |
-
def init_device_dtype(self, device, dtype):
|
484 |
-
self.device = device
|
485 |
-
self.dtype = dtype
|
486 |
-
|
487 |
-
@torch.no_grad()
|
488 |
-
def fetch_codes(self, input_audios_vocal,input_audios_bgm, additional_feats,layer_vocal=7,layer_bgm=7):
|
489 |
-
input_audio_vocal_0 = input_audios_vocal[[0],:]
|
490 |
-
input_audio_vocal_1 = input_audios_vocal[[1],:]
|
491 |
-
input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0)
|
492 |
-
input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1)
|
493 |
-
input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0
|
494 |
-
|
495 |
-
input_audio_bgm_0 = input_audios_bgm[[0],:]
|
496 |
-
input_audio_bgm_1 = input_audios_bgm[[1],:]
|
497 |
-
input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0)
|
498 |
-
input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1)
|
499 |
-
input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0
|
500 |
-
|
501 |
-
self.bestrq.eval()
|
502 |
-
|
503 |
-
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
504 |
-
# bestrq_middle = bestrq_middle.detach()
|
505 |
-
# bestrq_last = bestrq_last.detach()
|
506 |
-
bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal)
|
507 |
-
bestrq_emb = bestrq_emb.detach()
|
508 |
-
|
509 |
-
bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm)
|
510 |
-
bestrq_emb_bgm = bestrq_emb_bgm.detach()
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
self.rvq_bestrq_emb.eval()
|
515 |
-
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
516 |
-
|
517 |
-
self.rvq_bestrq_bgm_emb.eval()
|
518 |
-
quantized_bestrq_emb_bgm, codes_bestrq_emb_bgm, *_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t
|
519 |
-
|
520 |
-
|
521 |
-
if('spk' in additional_feats):
|
522 |
-
self.xvecmodel.eval()
|
523 |
-
spk_embeds = self.extract_spk_embeds(input_audios)
|
524 |
-
else:
|
525 |
-
spk_embeds = None
|
526 |
-
|
527 |
-
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
528 |
-
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
529 |
-
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
530 |
-
return [codes_bestrq_emb,codes_bestrq_emb_bgm], [bestrq_emb,bestrq_emb_bgm], spk_embeds
|
531 |
-
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
532 |
-
|
533 |
-
@torch.no_grad()
|
534 |
-
def fetch_codes_batch(self, input_audios_vocal, input_audios_bgm, additional_feats,layer_vocal=7,layer_bgm=7):
|
535 |
-
input_audio_vocal_0 = input_audios_vocal[:,0,:]
|
536 |
-
input_audio_vocal_1 = input_audios_vocal[:,1,:]
|
537 |
-
input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0)
|
538 |
-
input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1)
|
539 |
-
input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0
|
540 |
-
|
541 |
-
input_audio_bgm_0 = input_audios_bgm[:,0,:]
|
542 |
-
input_audio_bgm_1 = input_audios_bgm[:,1,:]
|
543 |
-
input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0)
|
544 |
-
input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1)
|
545 |
-
input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0
|
546 |
-
|
547 |
-
self.bestrq.eval()
|
548 |
-
|
549 |
-
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
550 |
-
# bestrq_middle = bestrq_middle.detach()
|
551 |
-
# bestrq_last = bestrq_last.detach()
|
552 |
-
bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal)
|
553 |
-
bestrq_emb = bestrq_emb.detach()
|
554 |
-
|
555 |
-
bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm)
|
556 |
-
bestrq_emb_bgm = bestrq_emb_bgm.detach()
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
self.rvq_bestrq_emb.eval()
|
561 |
-
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
562 |
-
|
563 |
-
self.rvq_bestrq_bgm_emb.eval()
|
564 |
-
quantized_bestrq_emb_bgm, codes_bestrq_emb_bgm, *_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t
|
565 |
-
|
566 |
-
|
567 |
-
if('spk' in additional_feats):
|
568 |
-
self.xvecmodel.eval()
|
569 |
-
spk_embeds = self.extract_spk_embeds(input_audios)
|
570 |
-
else:
|
571 |
-
spk_embeds = None
|
572 |
-
|
573 |
-
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
574 |
-
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
575 |
-
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
576 |
-
return [codes_bestrq_emb,codes_bestrq_emb_bgm], [bestrq_emb,bestrq_emb_bgm], spk_embeds
|
577 |
-
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
578 |
-
|
579 |
-
|
580 |
-
@torch.no_grad()
|
581 |
-
def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats,incontext_length=127,
|
582 |
-
guidance_scale=2, num_steps=20,
|
583 |
-
disable_progress=True, scenario='start_seg'):
|
584 |
-
classifier_free_guidance = guidance_scale > 1.0
|
585 |
-
device = self.device
|
586 |
-
dtype = self.dtype
|
587 |
-
# codes_bestrq_middle, codes_bestrq_last = codes
|
588 |
-
codes_bestrq_emb,codes_bestrq_emb_bgm = codes
|
589 |
-
|
590 |
-
|
591 |
-
batch_size = codes_bestrq_emb.shape[0]
|
592 |
-
|
593 |
-
|
594 |
-
quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
|
595 |
-
quantized_bestrq_emb_bgm,_,_=self.rvq_bestrq_bgm_emb.from_codes(codes_bestrq_emb_bgm)
|
596 |
-
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
597 |
-
quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm.permute(0,2,1).contiguous()
|
598 |
-
if('spk' in additional_feats):
|
599 |
-
spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
|
600 |
-
|
601 |
-
num_frames = quantized_bestrq_emb.shape[1]
|
602 |
-
|
603 |
-
num_channels_latents = self.num_channels
|
604 |
-
shape = (batch_size, num_frames, 64)
|
605 |
-
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
|
610 |
-
latent_masks[:,0:latent_length] = 2
|
611 |
-
if(scenario=='other_seg'):
|
612 |
-
latent_masks[:,0:incontext_length] = 1
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
617 |
-
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
618 |
-
quantized_bestrq_emb_bgm = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb_bgm \
|
619 |
-
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
620 |
-
true_latents = true_latents.permute(0,2,1).contiguous()
|
621 |
-
true_latents = self.normfeat.project_sample(true_latents)
|
622 |
-
true_latents = true_latents.permute(0,2,1).contiguous()
|
623 |
-
incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
624 |
-
incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
|
625 |
-
|
626 |
-
|
627 |
-
attention_mask=(latent_masks > 0.5)
|
628 |
-
B, L = attention_mask.size()
|
629 |
-
attention_mask = attention_mask.view(B, 1, L)
|
630 |
-
attention_mask = attention_mask * attention_mask.transpose(-1, -2)
|
631 |
-
attention_mask = attention_mask.unsqueeze(1)
|
632 |
-
latent_mask_input = self.mask_emb(latent_masks)
|
633 |
-
|
634 |
-
if('spk' in additional_feats):
|
635 |
-
# additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
|
636 |
-
additional_model_input = torch.cat([quantized_bestrq_emb,quantized_bestrq_emb_bgm, spk_embeds],2)
|
637 |
-
else:
|
638 |
-
# additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
|
639 |
-
additional_model_input = torch.cat([quantized_bestrq_emb,quantized_bestrq_emb_bgm],2)
|
640 |
-
|
641 |
-
temperature = 1.0
|
642 |
-
t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
|
643 |
-
latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
|
644 |
-
|
645 |
-
latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
|
646 |
-
latents = latents.permute(0,2,1).contiguous()
|
647 |
-
latents = self.normfeat.return_sample(latents)
|
648 |
-
# latents = latents.permute(0,2,1).contiguous()
|
649 |
-
return latents
|
650 |
-
|
651 |
-
@torch.no_grad()
|
652 |
-
def inference(self, input_audios_vocal,input_audios_bgm, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
|
653 |
-
disable_progress=True,layer_vocal=7,layer_bgm=3,scenario='start_seg'):
|
654 |
-
codes, embeds, spk_embeds = self.fetch_codes(input_audios_vocal,input_audios_bgm, additional_feats,layer_vocal,layer_bgm)
|
655 |
-
|
656 |
-
latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
|
657 |
-
guidance_scale=guidance_scale, num_steps=num_steps, \
|
658 |
-
disable_progress=disable_progress,scenario=scenario)
|
659 |
-
return latents
|
660 |
-
|
661 |
-
def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
|
662 |
-
divisor = 4
|
663 |
-
shape = (batch_size, num_channels_latents, num_frames, 32)
|
664 |
-
if(num_frames%divisor>0):
|
665 |
-
num_frames = round(num_frames/float(divisor))*divisor
|
666 |
-
shape = (batch_size, num_channels_latents, num_frames, 32)
|
667 |
-
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
668 |
-
return latents
|
669 |
-
|
670 |
-
|
|
|
1 |
+
import yaml
|
2 |
+
import random
|
3 |
+
import inspect
|
4 |
+
import numpy as np
|
5 |
+
from tqdm import tqdm
|
6 |
+
import typing as tp
|
7 |
+
from abc import ABC
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import torchaudio
|
13 |
+
|
14 |
+
from einops import repeat
|
15 |
+
from tools.torch_tools import wav_to_fbank
|
16 |
+
|
17 |
+
from diffusers.utils.torch_utils import randn_tensor
|
18 |
+
from transformers import HubertModel
|
19 |
+
from libs.rvq.descript_quantize3 import ResidualVectorQuantize
|
20 |
+
|
21 |
+
from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
|
22 |
+
from models_gpt.models.gpt2_config import GPT2Config
|
23 |
+
|
24 |
+
from torch.cuda.amp import autocast
|
25 |
+
from our_MERT_BESTRQ.test import load_model
|
26 |
+
|
27 |
+
class HubertModelWithFinalProj(HubertModel):
|
28 |
+
def __init__(self, config):
|
29 |
+
super().__init__(config)
|
30 |
+
|
31 |
+
# The final projection layer is only used for backward compatibility.
|
32 |
+
# Following https://github.com/auspicious3000/contentvec/issues/6
|
33 |
+
# Remove this layer is necessary to achieve the desired outcome.
|
34 |
+
print("hidden_size:",config.hidden_size)
|
35 |
+
print("classifier_proj_size:",config.classifier_proj_size)
|
36 |
+
self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
37 |
+
|
38 |
+
|
39 |
+
class SampleProcessor(torch.nn.Module):
|
40 |
+
def project_sample(self, x: torch.Tensor):
|
41 |
+
"""Project the original sample to the 'space' where the diffusion will happen."""
|
42 |
+
"""Project back from diffusion space to the actual sample space."""
|
43 |
+
return z
|
44 |
+
|
45 |
+
class Feature1DProcessor(SampleProcessor):
|
46 |
+
def __init__(self, dim: int = 100, power_std = 1., \
|
47 |
+
num_samples: int = 100_000, cal_num_frames: int = 600):
|
48 |
+
super().__init__()
|
49 |
+
|
50 |
+
self.num_samples = num_samples
|
51 |
+
self.dim = dim
|
52 |
+
self.power_std = power_std
|
53 |
+
self.cal_num_frames = cal_num_frames
|
54 |
+
self.register_buffer('counts', torch.zeros(1))
|
55 |
+
self.register_buffer('sum_x', torch.zeros(dim))
|
56 |
+
self.register_buffer('sum_x2', torch.zeros(dim))
|
57 |
+
self.register_buffer('sum_target_x2', torch.zeros(dim))
|
58 |
+
self.counts: torch.Tensor
|
59 |
+
self.sum_x: torch.Tensor
|
60 |
+
self.sum_x2: torch.Tensor
|
61 |
+
|
62 |
+
@property
|
63 |
+
def mean(self):
|
64 |
+
mean = self.sum_x / self.counts
|
65 |
+
if(self.counts < 10):
|
66 |
+
mean = torch.zeros_like(mean)
|
67 |
+
return mean
|
68 |
+
|
69 |
+
@property
|
70 |
+
def std(self):
|
71 |
+
std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
|
72 |
+
if(self.counts < 10):
|
73 |
+
std = torch.ones_like(std)
|
74 |
+
return std
|
75 |
+
|
76 |
+
@property
|
77 |
+
def target_std(self):
|
78 |
+
return 1
|
79 |
+
|
80 |
+
def project_sample(self, x: torch.Tensor):
|
81 |
+
assert x.dim() == 3
|
82 |
+
if self.counts.item() < self.num_samples:
|
83 |
+
self.counts += len(x)
|
84 |
+
self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
|
85 |
+
self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
|
86 |
+
rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
|
87 |
+
x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
|
88 |
+
return x
|
89 |
+
|
90 |
+
def return_sample(self, x: torch.Tensor):
|
91 |
+
assert x.dim() == 3
|
92 |
+
rescale = (self.std / self.target_std) ** self.power_std
|
93 |
+
x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
|
94 |
+
return x
|
95 |
+
|
96 |
+
def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
|
97 |
+
if(prior_text_encoder_hidden_states.shape[1]<len_size):
|
98 |
+
prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
|
99 |
+
torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
|
100 |
+
prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
|
101 |
+
dtype=prior_text_encoder_hidden_states.dtype)],1)
|
102 |
+
prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
|
103 |
+
else:
|
104 |
+
prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
|
105 |
+
prior_text_mask = prior_text_mask[:,0:len_size]
|
106 |
+
prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
|
107 |
+
return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
|
108 |
+
|
109 |
+
class BASECFM(torch.nn.Module, ABC):
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
estimator,
|
113 |
+
mlp
|
114 |
+
):
|
115 |
+
super().__init__()
|
116 |
+
self.sigma_min = 1e-4
|
117 |
+
|
118 |
+
self.estimator = estimator
|
119 |
+
self.mlp = mlp
|
120 |
+
|
121 |
+
@torch.inference_mode()
|
122 |
+
def forward(self, mu, n_timesteps, temperature=1.0):
|
123 |
+
"""Forward diffusion
|
124 |
+
|
125 |
+
Args:
|
126 |
+
mu (torch.Tensor): output of encoder
|
127 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
128 |
+
n_timesteps (int): number of diffusion steps
|
129 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
130 |
+
|
131 |
+
Returns:
|
132 |
+
sample: generated mel-spectrogram
|
133 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
134 |
+
"""
|
135 |
+
z = torch.randn_like(mu) * temperature
|
136 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
137 |
+
return self.solve_euler(z, t_span=t_span)
|
138 |
+
|
139 |
+
def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
|
140 |
+
"""
|
141 |
+
Fixed euler solver for ODEs.
|
142 |
+
Args:
|
143 |
+
x (torch.Tensor): random noise
|
144 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
145 |
+
shape: (n_timesteps + 1,)
|
146 |
+
mu (torch.Tensor): output of encoder
|
147 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
148 |
+
"""
|
149 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
150 |
+
noise = x.clone()
|
151 |
+
|
152 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
153 |
+
# Or in future might add like a return_all_steps flag
|
154 |
+
sol = []
|
155 |
+
|
156 |
+
for step in tqdm(range(1, len(t_span))):
|
157 |
+
x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
|
158 |
+
if(guidance_scale > 1.0):
|
159 |
+
|
160 |
+
model_input = torch.cat([ \
|
161 |
+
torch.cat([latent_mask_input, latent_mask_input], 0), \
|
162 |
+
torch.cat([incontext_x, incontext_x], 0), \
|
163 |
+
torch.cat([torch.zeros_like(mu), mu], 0), \
|
164 |
+
torch.cat([x, x], 0), \
|
165 |
+
], 2)
|
166 |
+
timestep=t.unsqueeze(-1).repeat(2)
|
167 |
+
|
168 |
+
dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
|
169 |
+
dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0)
|
170 |
+
dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond)
|
171 |
+
else:
|
172 |
+
model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
|
173 |
+
timestep=t.unsqueeze(-1)
|
174 |
+
dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
|
175 |
+
|
176 |
+
dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
|
177 |
+
x = x + dt * dphi_dt
|
178 |
+
t = t + dt
|
179 |
+
sol.append(x)
|
180 |
+
if step < len(t_span) - 1:
|
181 |
+
dt = t_span[step + 1] - t
|
182 |
+
|
183 |
+
return sol[-1]
|
184 |
+
|
185 |
+
def projection_loss(self,hidden_proj, bestrq_emb):
|
186 |
+
bsz = hidden_proj.shape[0]
|
187 |
+
|
188 |
+
hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
|
189 |
+
bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
|
190 |
+
|
191 |
+
proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
|
192 |
+
proj_loss = 1+proj_loss.mean()
|
193 |
+
|
194 |
+
return proj_loss
|
195 |
+
|
196 |
+
def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
|
197 |
+
"""Computes diffusion loss
|
198 |
+
|
199 |
+
Args:
|
200 |
+
x1 (torch.Tensor): Target
|
201 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
202 |
+
mu (torch.Tensor): output of encoder
|
203 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
204 |
+
|
205 |
+
Returns:
|
206 |
+
loss: conditional flow matching loss
|
207 |
+
y: conditional flow
|
208 |
+
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
209 |
+
"""
|
210 |
+
b = mu[0].shape[0]
|
211 |
+
len_x = x1.shape[2]
|
212 |
+
# random timestep
|
213 |
+
if(validation_mode):
|
214 |
+
t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
|
215 |
+
else:
|
216 |
+
t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
|
217 |
+
# sample noise p(x_0)
|
218 |
+
z = torch.randn_like(x1)
|
219 |
+
|
220 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
221 |
+
u = x1 - (1 - self.sigma_min) * z
|
222 |
+
model_input = torch.cat([*mu,y], 2)
|
223 |
+
t=t.squeeze(-1).squeeze(-1)
|
224 |
+
out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
|
225 |
+
hidden_layer_7 = out.hidden_states[7]
|
226 |
+
hidden_proj = self.mlp(hidden_layer_7)
|
227 |
+
out = out.last_hidden_state
|
228 |
+
out=out[:,:,-len_x:]
|
229 |
+
|
230 |
+
weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
|
231 |
+
loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
|
232 |
+
loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
|
233 |
+
loss = loss_re + loss_cos * 0.5
|
234 |
+
return loss, loss_re, loss_cos
|
235 |
+
|
236 |
+
class PromptCondAudioDiffusion(nn.Module):
|
237 |
+
def __init__(
|
238 |
+
self,
|
239 |
+
num_channels,
|
240 |
+
unet_model_name=None,
|
241 |
+
unet_model_config_path=None,
|
242 |
+
snr_gamma=None,
|
243 |
+
uncondition=True,
|
244 |
+
out_paint=False,
|
245 |
+
):
|
246 |
+
super().__init__()
|
247 |
+
|
248 |
+
assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
|
249 |
+
|
250 |
+
self.unet_model_name = unet_model_name
|
251 |
+
self.unet_model_config_path = unet_model_config_path
|
252 |
+
self.snr_gamma = snr_gamma
|
253 |
+
self.uncondition = uncondition
|
254 |
+
self.num_channels = num_channels
|
255 |
+
|
256 |
+
# https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
|
257 |
+
self.normfeat = Feature1DProcessor(dim=64)
|
258 |
+
|
259 |
+
self.sample_rate = 48000
|
260 |
+
self.num_samples_perseg = self.sample_rate * 20 // 1000
|
261 |
+
self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
|
262 |
+
self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
|
263 |
+
# self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
|
264 |
+
# self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
|
265 |
+
self.bestrq = load_model(
|
266 |
+
model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq',
|
267 |
+
checkpoint_dir='ckpt/encode-s12k.pt',
|
268 |
+
)
|
269 |
+
self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
|
270 |
+
self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
|
271 |
+
for v in self.bestrq.parameters():v.requires_grad = False
|
272 |
+
self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
|
273 |
+
self.rvq_bestrq_bgm_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
|
274 |
+
self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
|
275 |
+
for v in self.hubert.parameters():v.requires_grad = False
|
276 |
+
self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
|
277 |
+
# self.xvecmodel = XVECModel()
|
278 |
+
config = GPT2Config(n_positions=1000,n_layer=16,n_head=20,n_embd=2200,n_inner=4400)
|
279 |
+
unet = GPT2Model(config)
|
280 |
+
mlp = nn.Sequential(
|
281 |
+
nn.Linear(2200, 1024),
|
282 |
+
nn.SiLU(),
|
283 |
+
nn.Linear(1024, 1024),
|
284 |
+
nn.SiLU(),
|
285 |
+
nn.Linear(1024, 768)
|
286 |
+
)
|
287 |
+
self.set_from = "random"
|
288 |
+
self.cfm_wrapper = BASECFM(unet, mlp)
|
289 |
+
self.mask_emb = torch.nn.Embedding(3, 24)
|
290 |
+
print("Transformer initialized from pretrain.")
|
291 |
+
torch.cuda.empty_cache()
|
292 |
+
|
293 |
+
def compute_snr(self, timesteps):
|
294 |
+
"""
|
295 |
+
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
296 |
+
"""
|
297 |
+
alphas_cumprod = self.noise_scheduler.alphas_cumprod
|
298 |
+
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
299 |
+
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
300 |
+
|
301 |
+
# Expand the tensors.
|
302 |
+
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
303 |
+
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
304 |
+
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
305 |
+
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
306 |
+
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
307 |
+
|
308 |
+
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
309 |
+
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
310 |
+
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
311 |
+
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
312 |
+
|
313 |
+
# Compute SNR.
|
314 |
+
snr = (alpha / sigma) ** 2
|
315 |
+
return snr
|
316 |
+
|
317 |
+
def preprocess_audio(self, input_audios, threshold=0.9):
|
318 |
+
assert len(input_audios.shape) == 2, input_audios.shape
|
319 |
+
norm_value = torch.ones_like(input_audios[:,0])
|
320 |
+
max_volume = input_audios.abs().max(dim=-1)[0]
|
321 |
+
norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
|
322 |
+
return input_audios/norm_value.unsqueeze(-1)
|
323 |
+
|
324 |
+
def extract_wav2vec_embeds(self, input_audios,output_len):
|
325 |
+
wav2vec_stride = 2
|
326 |
+
|
327 |
+
wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
|
328 |
+
wav2vec_embeds_last=wav2vec_embeds[-1]
|
329 |
+
wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
|
330 |
+
return wav2vec_embeds_last
|
331 |
+
|
332 |
+
def extract_mert_embeds(self, input_audios):
|
333 |
+
prompt_stride = 3
|
334 |
+
inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
|
335 |
+
input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
|
336 |
+
prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
|
337 |
+
mert_emb= prompt_embeds[-1]
|
338 |
+
mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=375, mode='linear', align_corners=False).permute(0, 2, 1)
|
339 |
+
|
340 |
+
return mert_emb
|
341 |
+
|
342 |
+
def extract_bestrq_embeds(self, input_audio_vocal_0,input_audio_vocal_1,layer):
|
343 |
+
input_wav_mean = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0
|
344 |
+
input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
|
345 |
+
layer_results = input_wav_mean['layer_results']
|
346 |
+
bestrq_emb = layer_results[layer]
|
347 |
+
bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
|
348 |
+
return bestrq_emb
|
349 |
+
|
350 |
+
|
351 |
+
def extract_spk_embeds(self, input_audios):
|
352 |
+
spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
|
353 |
+
spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
|
354 |
+
return spk_embeds
|
355 |
+
|
356 |
+
def extract_lyric_feats(self, lyric):
|
357 |
+
with torch.no_grad():
|
358 |
+
try:
|
359 |
+
text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
|
360 |
+
except:
|
361 |
+
text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
|
362 |
+
text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
|
363 |
+
text_mask = text_mask.to(self.device)
|
364 |
+
text_encoder_hidden_states, text_mask, text_prompt_embeds = \
|
365 |
+
pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
|
366 |
+
text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
|
367 |
+
return text_encoder_hidden_states, text_mask
|
368 |
+
|
369 |
+
def extract_energy_bar(self, input_audios):
|
370 |
+
if(input_audios.shape[-1] % self.num_samples_perseg > 0):
|
371 |
+
energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
|
372 |
+
else:
|
373 |
+
energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
|
374 |
+
energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
|
375 |
+
energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
|
376 |
+
energy_embedding = self.energy_embedding(energy_bar)
|
377 |
+
energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
|
378 |
+
return energy_embedding
|
379 |
+
|
380 |
+
def forward(self, input_audios_vocal,input_audios_bgm, lyric, latents, latent_masks, validation_mode=False, \
|
381 |
+
additional_feats = ['spk', 'lyric'], \
|
382 |
+
train_rvq=True, train_ssl=False,layer_vocal=7,layer_bgm=7):
|
383 |
+
if not hasattr(self,"device"):
|
384 |
+
self.device = input_audios_vocal.device
|
385 |
+
if not hasattr(self,"dtype"):
|
386 |
+
self.dtype = input_audios_vocal.dtype
|
387 |
+
device = self.device
|
388 |
+
input_audio_vocal_0 = input_audios_vocal[:,0,:]
|
389 |
+
input_audio_vocal_1 = input_audios_vocal[:,1,:]
|
390 |
+
input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0)
|
391 |
+
input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1)
|
392 |
+
input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0
|
393 |
+
|
394 |
+
input_audio_bgm_0 = input_audios_bgm[:,0,:]
|
395 |
+
input_audio_bgm_1 = input_audios_bgm[:,1,:]
|
396 |
+
input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0)
|
397 |
+
input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1)
|
398 |
+
input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0
|
399 |
+
|
400 |
+
if(train_ssl):
|
401 |
+
self.wav2vec.train()
|
402 |
+
wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
|
403 |
+
self.clap_embd_extractor.train()
|
404 |
+
prompt_embeds = self.extract_mert_embeds(input_audios)
|
405 |
+
if('spk' in additional_feats):
|
406 |
+
self.xvecmodel.train()
|
407 |
+
spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
|
408 |
+
else:
|
409 |
+
with torch.no_grad():
|
410 |
+
with autocast(enabled=False):
|
411 |
+
bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal)
|
412 |
+
bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm)
|
413 |
+
# mert_emb = self.extract_mert_embeds(input_audios_mert)
|
414 |
+
output_len = bestrq_emb.shape[2]
|
415 |
+
wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_vocal_wav2vec+input_audios_bgm_wav2vec,output_len)
|
416 |
+
|
417 |
+
|
418 |
+
bestrq_emb = bestrq_emb.detach()
|
419 |
+
bestrq_emb_bgm = bestrq_emb_bgm.detach()
|
420 |
+
|
421 |
+
if('lyric' in additional_feats):
|
422 |
+
text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
|
423 |
+
else:
|
424 |
+
text_encoder_hidden_states, text_mask = None, None
|
425 |
+
|
426 |
+
|
427 |
+
if(train_rvq):
|
428 |
+
quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
429 |
+
quantized_bestrq_emb_bgm, _, _, commitment_loss_bestrq_emb_bgm, codebook_loss_bestrq_emb_bgm,_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t
|
430 |
+
else:
|
431 |
+
bestrq_emb = bestrq_emb.float()
|
432 |
+
self.rvq_bestrq_emb.eval()
|
433 |
+
# with autocast(enabled=False):
|
434 |
+
quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
435 |
+
commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
|
436 |
+
codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
|
437 |
+
quantized_bestrq_emb = quantized_bestrq_emb.detach()
|
438 |
+
|
439 |
+
commitment_loss = commitment_loss_bestrq_emb+commitment_loss_bestrq_emb_bgm
|
440 |
+
codebook_loss = codebook_loss_bestrq_emb+codebook_loss_bestrq_emb_bgm
|
441 |
+
|
442 |
+
|
443 |
+
alpha=1
|
444 |
+
quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
|
445 |
+
quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm * alpha + bestrq_emb_bgm * (1-alpha)
|
446 |
+
|
447 |
+
|
448 |
+
|
449 |
+
|
450 |
+
scenario = np.random.choice(['start_seg', 'other_seg'])
|
451 |
+
if(scenario == 'other_seg'):
|
452 |
+
for binx in range(input_audios_vocal.shape[0]):
|
453 |
+
# latent_masks[binx,0:64] = 1
|
454 |
+
latent_masks[binx,0:random.randint(64,128)] = 1
|
455 |
+
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
456 |
+
quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm.permute(0,2,1).contiguous()
|
457 |
+
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
458 |
+
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
459 |
+
quantized_bestrq_emb_bgm = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb_bgm \
|
460 |
+
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
461 |
+
|
462 |
+
|
463 |
+
|
464 |
+
|
465 |
+
if self.uncondition:
|
466 |
+
mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
|
467 |
+
if len(mask_indices) > 0:
|
468 |
+
quantized_bestrq_emb[mask_indices] = 0
|
469 |
+
quantized_bestrq_emb_bgm[mask_indices] = 0
|
470 |
+
latents = latents.permute(0,2,1).contiguous()
|
471 |
+
latents = self.normfeat.project_sample(latents)
|
472 |
+
latents = latents.permute(0,2,1).contiguous()
|
473 |
+
incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
474 |
+
attention_mask=(latent_masks > 0.5)
|
475 |
+
B, L = attention_mask.size()
|
476 |
+
attention_mask = attention_mask.view(B, 1, L)
|
477 |
+
attention_mask = attention_mask * attention_mask.transpose(-1, -2)
|
478 |
+
attention_mask = attention_mask.unsqueeze(1)
|
479 |
+
latent_mask_input = self.mask_emb(latent_masks)
|
480 |
+
loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb,quantized_bestrq_emb_bgm], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
|
481 |
+
return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
|
482 |
+
|
483 |
+
def init_device_dtype(self, device, dtype):
|
484 |
+
self.device = device
|
485 |
+
self.dtype = dtype
|
486 |
+
|
487 |
+
@torch.no_grad()
|
488 |
+
def fetch_codes(self, input_audios_vocal,input_audios_bgm, additional_feats,layer_vocal=7,layer_bgm=7):
|
489 |
+
input_audio_vocal_0 = input_audios_vocal[[0],:]
|
490 |
+
input_audio_vocal_1 = input_audios_vocal[[1],:]
|
491 |
+
input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0)
|
492 |
+
input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1)
|
493 |
+
input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0
|
494 |
+
|
495 |
+
input_audio_bgm_0 = input_audios_bgm[[0],:]
|
496 |
+
input_audio_bgm_1 = input_audios_bgm[[1],:]
|
497 |
+
input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0)
|
498 |
+
input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1)
|
499 |
+
input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0
|
500 |
+
|
501 |
+
self.bestrq.eval()
|
502 |
+
|
503 |
+
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
504 |
+
# bestrq_middle = bestrq_middle.detach()
|
505 |
+
# bestrq_last = bestrq_last.detach()
|
506 |
+
bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal)
|
507 |
+
bestrq_emb = bestrq_emb.detach()
|
508 |
+
|
509 |
+
bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm)
|
510 |
+
bestrq_emb_bgm = bestrq_emb_bgm.detach()
|
511 |
+
|
512 |
+
|
513 |
+
|
514 |
+
self.rvq_bestrq_emb.eval()
|
515 |
+
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
516 |
+
|
517 |
+
self.rvq_bestrq_bgm_emb.eval()
|
518 |
+
quantized_bestrq_emb_bgm, codes_bestrq_emb_bgm, *_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t
|
519 |
+
|
520 |
+
|
521 |
+
if('spk' in additional_feats):
|
522 |
+
self.xvecmodel.eval()
|
523 |
+
spk_embeds = self.extract_spk_embeds(input_audios)
|
524 |
+
else:
|
525 |
+
spk_embeds = None
|
526 |
+
|
527 |
+
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
528 |
+
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
529 |
+
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
530 |
+
return [codes_bestrq_emb,codes_bestrq_emb_bgm], [bestrq_emb,bestrq_emb_bgm], spk_embeds
|
531 |
+
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
532 |
+
|
533 |
+
@torch.no_grad()
|
534 |
+
def fetch_codes_batch(self, input_audios_vocal, input_audios_bgm, additional_feats,layer_vocal=7,layer_bgm=7):
|
535 |
+
input_audio_vocal_0 = input_audios_vocal[:,0,:]
|
536 |
+
input_audio_vocal_1 = input_audios_vocal[:,1,:]
|
537 |
+
input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0)
|
538 |
+
input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1)
|
539 |
+
input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0
|
540 |
+
|
541 |
+
input_audio_bgm_0 = input_audios_bgm[:,0,:]
|
542 |
+
input_audio_bgm_1 = input_audios_bgm[:,1,:]
|
543 |
+
input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0)
|
544 |
+
input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1)
|
545 |
+
input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0
|
546 |
+
|
547 |
+
self.bestrq.eval()
|
548 |
+
|
549 |
+
# bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
|
550 |
+
# bestrq_middle = bestrq_middle.detach()
|
551 |
+
# bestrq_last = bestrq_last.detach()
|
552 |
+
bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal)
|
553 |
+
bestrq_emb = bestrq_emb.detach()
|
554 |
+
|
555 |
+
bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm)
|
556 |
+
bestrq_emb_bgm = bestrq_emb_bgm.detach()
|
557 |
+
|
558 |
+
|
559 |
+
|
560 |
+
self.rvq_bestrq_emb.eval()
|
561 |
+
quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
|
562 |
+
|
563 |
+
self.rvq_bestrq_bgm_emb.eval()
|
564 |
+
quantized_bestrq_emb_bgm, codes_bestrq_emb_bgm, *_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t
|
565 |
+
|
566 |
+
|
567 |
+
if('spk' in additional_feats):
|
568 |
+
self.xvecmodel.eval()
|
569 |
+
spk_embeds = self.extract_spk_embeds(input_audios)
|
570 |
+
else:
|
571 |
+
spk_embeds = None
|
572 |
+
|
573 |
+
# return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
|
574 |
+
# return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
|
575 |
+
# return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
|
576 |
+
return [codes_bestrq_emb,codes_bestrq_emb_bgm], [bestrq_emb,bestrq_emb_bgm], spk_embeds
|
577 |
+
# return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
|
578 |
+
|
579 |
+
|
580 |
+
@torch.no_grad()
|
581 |
+
def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats,incontext_length=127,
|
582 |
+
guidance_scale=2, num_steps=20,
|
583 |
+
disable_progress=True, scenario='start_seg'):
|
584 |
+
classifier_free_guidance = guidance_scale > 1.0
|
585 |
+
device = self.device
|
586 |
+
dtype = self.dtype
|
587 |
+
# codes_bestrq_middle, codes_bestrq_last = codes
|
588 |
+
codes_bestrq_emb,codes_bestrq_emb_bgm = codes
|
589 |
+
|
590 |
+
|
591 |
+
batch_size = codes_bestrq_emb.shape[0]
|
592 |
+
|
593 |
+
|
594 |
+
quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
|
595 |
+
quantized_bestrq_emb_bgm,_,_=self.rvq_bestrq_bgm_emb.from_codes(codes_bestrq_emb_bgm)
|
596 |
+
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
597 |
+
quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm.permute(0,2,1).contiguous()
|
598 |
+
if('spk' in additional_feats):
|
599 |
+
spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
|
600 |
+
|
601 |
+
num_frames = quantized_bestrq_emb.shape[1]
|
602 |
+
|
603 |
+
num_channels_latents = self.num_channels
|
604 |
+
shape = (batch_size, num_frames, 64)
|
605 |
+
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
606 |
+
|
607 |
+
|
608 |
+
|
609 |
+
latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
|
610 |
+
latent_masks[:,0:latent_length] = 2
|
611 |
+
if(scenario=='other_seg'):
|
612 |
+
latent_masks[:,0:incontext_length] = 1
|
613 |
+
|
614 |
+
|
615 |
+
|
616 |
+
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
617 |
+
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
618 |
+
quantized_bestrq_emb_bgm = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb_bgm \
|
619 |
+
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
620 |
+
true_latents = true_latents.permute(0,2,1).contiguous()
|
621 |
+
true_latents = self.normfeat.project_sample(true_latents)
|
622 |
+
true_latents = true_latents.permute(0,2,1).contiguous()
|
623 |
+
incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
624 |
+
incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
|
625 |
+
|
626 |
+
|
627 |
+
attention_mask=(latent_masks > 0.5)
|
628 |
+
B, L = attention_mask.size()
|
629 |
+
attention_mask = attention_mask.view(B, 1, L)
|
630 |
+
attention_mask = attention_mask * attention_mask.transpose(-1, -2)
|
631 |
+
attention_mask = attention_mask.unsqueeze(1)
|
632 |
+
latent_mask_input = self.mask_emb(latent_masks)
|
633 |
+
|
634 |
+
if('spk' in additional_feats):
|
635 |
+
# additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
|
636 |
+
additional_model_input = torch.cat([quantized_bestrq_emb,quantized_bestrq_emb_bgm, spk_embeds],2)
|
637 |
+
else:
|
638 |
+
# additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
|
639 |
+
additional_model_input = torch.cat([quantized_bestrq_emb,quantized_bestrq_emb_bgm],2)
|
640 |
+
|
641 |
+
temperature = 1.0
|
642 |
+
t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
|
643 |
+
latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
|
644 |
+
|
645 |
+
latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
|
646 |
+
latents = latents.permute(0,2,1).contiguous()
|
647 |
+
latents = self.normfeat.return_sample(latents)
|
648 |
+
# latents = latents.permute(0,2,1).contiguous()
|
649 |
+
return latents
|
650 |
+
|
651 |
+
@torch.no_grad()
|
652 |
+
def inference(self, input_audios_vocal,input_audios_bgm, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
|
653 |
+
disable_progress=True,layer_vocal=7,layer_bgm=3,scenario='start_seg'):
|
654 |
+
codes, embeds, spk_embeds = self.fetch_codes(input_audios_vocal,input_audios_bgm, additional_feats,layer_vocal,layer_bgm)
|
655 |
+
|
656 |
+
latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
|
657 |
+
guidance_scale=guidance_scale, num_steps=num_steps, \
|
658 |
+
disable_progress=disable_progress,scenario=scenario)
|
659 |
+
return latents
|
660 |
+
|
661 |
+
def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
|
662 |
+
divisor = 4
|
663 |
+
shape = (batch_size, num_channels_latents, num_frames, 32)
|
664 |
+
if(num_frames%divisor>0):
|
665 |
+
num_frames = round(num_frames/float(divisor))*divisor
|
666 |
+
shape = (batch_size, num_channels_latents, num_frames, 32)
|
667 |
+
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
668 |
+
return latents
|
669 |
+
|
670 |
+
|
codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_additionalemb.py
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_flow.py
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/pinyin/symbols.py
CHANGED
@@ -1,71 +1,71 @@
|
|
1 |
-
_pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"]
|
2 |
-
|
3 |
-
_initials = [
|
4 |
-
"^",
|
5 |
-
"b",
|
6 |
-
"c",
|
7 |
-
"ch",
|
8 |
-
"d",
|
9 |
-
"f",
|
10 |
-
"g",
|
11 |
-
"h",
|
12 |
-
"j",
|
13 |
-
"k",
|
14 |
-
"l",
|
15 |
-
"m",
|
16 |
-
"n",
|
17 |
-
"p",
|
18 |
-
"q",
|
19 |
-
"r",
|
20 |
-
"s",
|
21 |
-
"sh",
|
22 |
-
"t",
|
23 |
-
"x",
|
24 |
-
"z",
|
25 |
-
"zh",
|
26 |
-
]
|
27 |
-
|
28 |
-
_tones = ["1", "2", "3", "4", "5"]
|
29 |
-
|
30 |
-
_finals = [
|
31 |
-
"a",
|
32 |
-
"ai",
|
33 |
-
"an",
|
34 |
-
"ang",
|
35 |
-
"ao",
|
36 |
-
"e",
|
37 |
-
"ei",
|
38 |
-
"en",
|
39 |
-
"eng",
|
40 |
-
"er",
|
41 |
-
"i",
|
42 |
-
"ia",
|
43 |
-
"ian",
|
44 |
-
"iang",
|
45 |
-
"iao",
|
46 |
-
"ie",
|
47 |
-
"ii",
|
48 |
-
"iii",
|
49 |
-
"in",
|
50 |
-
"ing",
|
51 |
-
"iong",
|
52 |
-
"iou",
|
53 |
-
"o",
|
54 |
-
"ong",
|
55 |
-
"ou",
|
56 |
-
"u",
|
57 |
-
"ua",
|
58 |
-
"uai",
|
59 |
-
"uan",
|
60 |
-
"uang",
|
61 |
-
"uei",
|
62 |
-
"uen",
|
63 |
-
"ueng",
|
64 |
-
"uo",
|
65 |
-
"v",
|
66 |
-
"van",
|
67 |
-
"ve",
|
68 |
-
"vn",
|
69 |
-
]
|
70 |
-
|
71 |
-
symbols = _pause + _initials + [i + j for i in _finals for j in _tones]
|
|
|
1 |
+
_pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"]
|
2 |
+
|
3 |
+
_initials = [
|
4 |
+
"^",
|
5 |
+
"b",
|
6 |
+
"c",
|
7 |
+
"ch",
|
8 |
+
"d",
|
9 |
+
"f",
|
10 |
+
"g",
|
11 |
+
"h",
|
12 |
+
"j",
|
13 |
+
"k",
|
14 |
+
"l",
|
15 |
+
"m",
|
16 |
+
"n",
|
17 |
+
"p",
|
18 |
+
"q",
|
19 |
+
"r",
|
20 |
+
"s",
|
21 |
+
"sh",
|
22 |
+
"t",
|
23 |
+
"x",
|
24 |
+
"z",
|
25 |
+
"zh",
|
26 |
+
]
|
27 |
+
|
28 |
+
_tones = ["1", "2", "3", "4", "5"]
|
29 |
+
|
30 |
+
_finals = [
|
31 |
+
"a",
|
32 |
+
"ai",
|
33 |
+
"an",
|
34 |
+
"ang",
|
35 |
+
"ao",
|
36 |
+
"e",
|
37 |
+
"ei",
|
38 |
+
"en",
|
39 |
+
"eng",
|
40 |
+
"er",
|
41 |
+
"i",
|
42 |
+
"ia",
|
43 |
+
"ian",
|
44 |
+
"iang",
|
45 |
+
"iao",
|
46 |
+
"ie",
|
47 |
+
"ii",
|
48 |
+
"iii",
|
49 |
+
"in",
|
50 |
+
"ing",
|
51 |
+
"iong",
|
52 |
+
"iou",
|
53 |
+
"o",
|
54 |
+
"ong",
|
55 |
+
"ou",
|
56 |
+
"u",
|
57 |
+
"ua",
|
58 |
+
"uai",
|
59 |
+
"uan",
|
60 |
+
"uang",
|
61 |
+
"uei",
|
62 |
+
"uen",
|
63 |
+
"ueng",
|
64 |
+
"uo",
|
65 |
+
"v",
|
66 |
+
"van",
|
67 |
+
"ve",
|
68 |
+
"vn",
|
69 |
+
]
|
70 |
+
|
71 |
+
symbols = _pause + _initials + [i + j for i in _finals for j in _tones]
|
codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k.py
CHANGED
@@ -1,47 +1,47 @@
|
|
1 |
-
import json
|
2 |
-
import torch
|
3 |
-
from tqdm import tqdm
|
4 |
-
import torchaudio
|
5 |
-
import librosa
|
6 |
-
import os
|
7 |
-
import math
|
8 |
-
import numpy as np
|
9 |
-
from tools.get_bsrnnvae import get_bsrnnvae
|
10 |
-
import tools.torch_tools as torch_tools
|
11 |
-
|
12 |
-
class Tango:
|
13 |
-
def __init__(self, \
|
14 |
-
device="cuda:0"):
|
15 |
-
|
16 |
-
self.sample_rate = 44100
|
17 |
-
self.device = device
|
18 |
-
|
19 |
-
self.vae = get_bsrnnvae()
|
20 |
-
self.vae = self.vae.eval().to(device)
|
21 |
-
|
22 |
-
def sound2sound_generate_longterm(self, fname, batch_size=1, duration=15.36, steps=200, disable_progress=False):
|
23 |
-
""" Genrate audio without condition. """
|
24 |
-
num_frames = math.ceil(duration * 100. / 8)
|
25 |
-
with torch.no_grad():
|
26 |
-
orig_samples, fs = torchaudio.load(fname)
|
27 |
-
if(fs!=44100):
|
28 |
-
orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100)
|
29 |
-
fs = 44100
|
30 |
-
if(orig_samples.shape[-1]<int(duration*44100*2)):
|
31 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*44100*2+480)-orig_samples.shape[-1], \
|
32 |
-
dtype=orig_samples.dtype, device=orig_samples.device)], -1)
|
33 |
-
# orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
34 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
35 |
-
if(fs!=44100):orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100)
|
36 |
-
# resampled_audios = orig_samples[[0],int(4.64*44100):int(35.36*48000)+480].clamp(-1,1)
|
37 |
-
resampled_audios = orig_samples[[0],0:int(duration*2*44100)+480].clamp(-1,1)
|
38 |
-
orig_samples = orig_samples[[0],0:int(duration*2*44100)]
|
39 |
-
|
40 |
-
audio = self.vae(orig_samples[:,None,:])[:,0,:]
|
41 |
-
|
42 |
-
if(orig_samples.shape[-1]<audio.shape[-1]):
|
43 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
|
44 |
-
else:
|
45 |
-
orig_samples = orig_samples[:,0:audio.shape[-1]]
|
46 |
-
output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
|
47 |
-
return output
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
import torchaudio
|
5 |
+
import librosa
|
6 |
+
import os
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
from tools.get_bsrnnvae import get_bsrnnvae
|
10 |
+
import tools.torch_tools as torch_tools
|
11 |
+
|
12 |
+
class Tango:
|
13 |
+
def __init__(self, \
|
14 |
+
device="cuda:0"):
|
15 |
+
|
16 |
+
self.sample_rate = 44100
|
17 |
+
self.device = device
|
18 |
+
|
19 |
+
self.vae = get_bsrnnvae()
|
20 |
+
self.vae = self.vae.eval().to(device)
|
21 |
+
|
22 |
+
def sound2sound_generate_longterm(self, fname, batch_size=1, duration=15.36, steps=200, disable_progress=False):
|
23 |
+
""" Genrate audio without condition. """
|
24 |
+
num_frames = math.ceil(duration * 100. / 8)
|
25 |
+
with torch.no_grad():
|
26 |
+
orig_samples, fs = torchaudio.load(fname)
|
27 |
+
if(fs!=44100):
|
28 |
+
orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100)
|
29 |
+
fs = 44100
|
30 |
+
if(orig_samples.shape[-1]<int(duration*44100*2)):
|
31 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*44100*2+480)-orig_samples.shape[-1], \
|
32 |
+
dtype=orig_samples.dtype, device=orig_samples.device)], -1)
|
33 |
+
# orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
34 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
35 |
+
if(fs!=44100):orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100)
|
36 |
+
# resampled_audios = orig_samples[[0],int(4.64*44100):int(35.36*48000)+480].clamp(-1,1)
|
37 |
+
resampled_audios = orig_samples[[0],0:int(duration*2*44100)+480].clamp(-1,1)
|
38 |
+
orig_samples = orig_samples[[0],0:int(duration*2*44100)]
|
39 |
+
|
40 |
+
audio = self.vae(orig_samples[:,None,:])[:,0,:]
|
41 |
+
|
42 |
+
if(orig_samples.shape[-1]<audio.shape[-1]):
|
43 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
|
44 |
+
else:
|
45 |
+
orig_samples = orig_samples[:,0:audio.shape[-1]]
|
46 |
+
output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
|
47 |
+
return output
|
codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k_vocal.py
CHANGED
@@ -1,47 +1,47 @@
|
|
1 |
-
import json
|
2 |
-
import torch
|
3 |
-
from tqdm import tqdm
|
4 |
-
import torchaudio
|
5 |
-
import librosa
|
6 |
-
import os
|
7 |
-
import math
|
8 |
-
import numpy as np
|
9 |
-
from tools.get_bsrnnvae import get_bsrnnvae
|
10 |
-
import tools.torch_tools as torch_tools
|
11 |
-
|
12 |
-
class Tango:
|
13 |
-
def __init__(self, \
|
14 |
-
device="cuda:0"):
|
15 |
-
|
16 |
-
self.sample_rate = 44100
|
17 |
-
self.device = device
|
18 |
-
|
19 |
-
self.vae = get_bsrnnvae()
|
20 |
-
self.vae = self.vae.eval().to(device)
|
21 |
-
|
22 |
-
def sound2sound_generate_longterm(self, fname, batch_size=1, duration=20.48, steps=200, disable_progress=False):
|
23 |
-
""" Genrate audio without condition. """
|
24 |
-
num_frames = math.ceil(duration * 100. / 8)
|
25 |
-
with torch.no_grad():
|
26 |
-
orig_samples, fs = torchaudio.load(fname)
|
27 |
-
if(fs!=44100):
|
28 |
-
orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100)
|
29 |
-
fs = 44100
|
30 |
-
if(orig_samples.shape[-1]<int(duration*44100*2)):
|
31 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*44100*2+480)-orig_samples.shape[-1], \
|
32 |
-
dtype=orig_samples.dtype, device=orig_samples.device)], -1)
|
33 |
-
# orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
34 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
35 |
-
if(fs!=44100):orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100)
|
36 |
-
# resampled_audios = orig_samples[[0],int(4.64*44100):int(35.36*48000)+480].clamp(-1,1)
|
37 |
-
resampled_audios = orig_samples[[0],0:int(duration*2*44100)+480].clamp(-1,1)
|
38 |
-
orig_samples = orig_samples[[0],0:int(duration*2*44100)]
|
39 |
-
|
40 |
-
audio = self.vae(orig_samples[:,None,:])[:,0,:]
|
41 |
-
|
42 |
-
if(orig_samples.shape[-1]<audio.shape[-1]):
|
43 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
|
44 |
-
else:
|
45 |
-
orig_samples = orig_samples[:,0:audio.shape[-1]]
|
46 |
-
output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
|
47 |
-
return output
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
import torchaudio
|
5 |
+
import librosa
|
6 |
+
import os
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
from tools.get_bsrnnvae import get_bsrnnvae
|
10 |
+
import tools.torch_tools as torch_tools
|
11 |
+
|
12 |
+
class Tango:
|
13 |
+
def __init__(self, \
|
14 |
+
device="cuda:0"):
|
15 |
+
|
16 |
+
self.sample_rate = 44100
|
17 |
+
self.device = device
|
18 |
+
|
19 |
+
self.vae = get_bsrnnvae()
|
20 |
+
self.vae = self.vae.eval().to(device)
|
21 |
+
|
22 |
+
def sound2sound_generate_longterm(self, fname, batch_size=1, duration=20.48, steps=200, disable_progress=False):
|
23 |
+
""" Genrate audio without condition. """
|
24 |
+
num_frames = math.ceil(duration * 100. / 8)
|
25 |
+
with torch.no_grad():
|
26 |
+
orig_samples, fs = torchaudio.load(fname)
|
27 |
+
if(fs!=44100):
|
28 |
+
orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100)
|
29 |
+
fs = 44100
|
30 |
+
if(orig_samples.shape[-1]<int(duration*44100*2)):
|
31 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*44100*2+480)-orig_samples.shape[-1], \
|
32 |
+
dtype=orig_samples.dtype, device=orig_samples.device)], -1)
|
33 |
+
# orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
34 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
35 |
+
if(fs!=44100):orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100)
|
36 |
+
# resampled_audios = orig_samples[[0],int(4.64*44100):int(35.36*48000)+480].clamp(-1,1)
|
37 |
+
resampled_audios = orig_samples[[0],0:int(duration*2*44100)+480].clamp(-1,1)
|
38 |
+
orig_samples = orig_samples[[0],0:int(duration*2*44100)]
|
39 |
+
|
40 |
+
audio = self.vae(orig_samples[:,None,:])[:,0,:]
|
41 |
+
|
42 |
+
if(orig_samples.shape[-1]<audio.shape[-1]):
|
43 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
|
44 |
+
else:
|
45 |
+
orig_samples = orig_samples[:,0:audio.shape[-1]]
|
46 |
+
output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
|
47 |
+
return output
|
codeclm/tokenizer/Flow1dVAE/tools/infer_hifigan48k_speech.py
CHANGED
@@ -1,56 +1,56 @@
|
|
1 |
-
import json
|
2 |
-
import torch
|
3 |
-
from tqdm import tqdm
|
4 |
-
import torchaudio
|
5 |
-
import librosa
|
6 |
-
import os
|
7 |
-
import math
|
8 |
-
import numpy as np
|
9 |
-
from get_melvaehifigan48k import build_pretrained_models
|
10 |
-
import tools.torch_tools as torch_tools
|
11 |
-
|
12 |
-
class Tango:
|
13 |
-
def __init__(self, \
|
14 |
-
device="cuda:0"):
|
15 |
-
|
16 |
-
self.sample_rate = 48000
|
17 |
-
self.device = device
|
18 |
-
|
19 |
-
self.vae, self.stft = build_pretrained_models()
|
20 |
-
self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
|
21 |
-
|
22 |
-
def mel_spectrogram_to_waveform(self, mel_spectrogram):
|
23 |
-
if mel_spectrogram.dim() == 4:
|
24 |
-
mel_spectrogram = mel_spectrogram.squeeze(1)
|
25 |
-
|
26 |
-
waveform = self.vocoder(mel_spectrogram)
|
27 |
-
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
28 |
-
waveform = waveform.cpu().float()
|
29 |
-
return waveform
|
30 |
-
|
31 |
-
def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
|
32 |
-
""" Genrate audio without condition. """
|
33 |
-
num_frames = math.ceil(duration * 100. / 8)
|
34 |
-
with torch.no_grad():
|
35 |
-
orig_samples, fs = torchaudio.load(fname)
|
36 |
-
if(orig_samples.shape[-1]<int(duration*48000)):
|
37 |
-
orig_samples = orig_samples.repeat(1,math.ceil(int(duration*48000)/float(orig_samples.shape[-1])))
|
38 |
-
# orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
39 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
40 |
-
if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
|
41 |
-
# resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
|
42 |
-
resampled_audios = orig_samples[[0],0:int(duration*48000)+480].clamp(-1,1)
|
43 |
-
orig_samples = orig_samples[[0],0:int(duration*48000)]
|
44 |
-
|
45 |
-
mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
|
46 |
-
mel = mel.unsqueeze(1).to(self.device)
|
47 |
-
|
48 |
-
audio = self.vae.decode_to_waveform(mel)
|
49 |
-
audio = torch.from_numpy(audio)
|
50 |
-
|
51 |
-
if(orig_samples.shape[-1]<audio.shape[-1]):
|
52 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
|
53 |
-
else:
|
54 |
-
orig_samples = orig_samples[:,0:audio.shape[-1]]
|
55 |
-
output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
|
56 |
-
return output
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
import torchaudio
|
5 |
+
import librosa
|
6 |
+
import os
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
from get_melvaehifigan48k import build_pretrained_models
|
10 |
+
import tools.torch_tools as torch_tools
|
11 |
+
|
12 |
+
class Tango:
|
13 |
+
def __init__(self, \
|
14 |
+
device="cuda:0"):
|
15 |
+
|
16 |
+
self.sample_rate = 48000
|
17 |
+
self.device = device
|
18 |
+
|
19 |
+
self.vae, self.stft = build_pretrained_models()
|
20 |
+
self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
|
21 |
+
|
22 |
+
def mel_spectrogram_to_waveform(self, mel_spectrogram):
|
23 |
+
if mel_spectrogram.dim() == 4:
|
24 |
+
mel_spectrogram = mel_spectrogram.squeeze(1)
|
25 |
+
|
26 |
+
waveform = self.vocoder(mel_spectrogram)
|
27 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
28 |
+
waveform = waveform.cpu().float()
|
29 |
+
return waveform
|
30 |
+
|
31 |
+
def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
|
32 |
+
""" Genrate audio without condition. """
|
33 |
+
num_frames = math.ceil(duration * 100. / 8)
|
34 |
+
with torch.no_grad():
|
35 |
+
orig_samples, fs = torchaudio.load(fname)
|
36 |
+
if(orig_samples.shape[-1]<int(duration*48000)):
|
37 |
+
orig_samples = orig_samples.repeat(1,math.ceil(int(duration*48000)/float(orig_samples.shape[-1])))
|
38 |
+
# orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
39 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
40 |
+
if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
|
41 |
+
# resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
|
42 |
+
resampled_audios = orig_samples[[0],0:int(duration*48000)+480].clamp(-1,1)
|
43 |
+
orig_samples = orig_samples[[0],0:int(duration*48000)]
|
44 |
+
|
45 |
+
mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
|
46 |
+
mel = mel.unsqueeze(1).to(self.device)
|
47 |
+
|
48 |
+
audio = self.vae.decode_to_waveform(mel)
|
49 |
+
audio = torch.from_numpy(audio)
|
50 |
+
|
51 |
+
if(orig_samples.shape[-1]<audio.shape[-1]):
|
52 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
|
53 |
+
else:
|
54 |
+
orig_samples = orig_samples[:,0:audio.shape[-1]]
|
55 |
+
output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
|
56 |
+
return output
|
codeclm/tokenizer/Flow1dVAE/tools/infer_hifigan48k_vocal.py
CHANGED
@@ -1,57 +1,57 @@
|
|
1 |
-
import json
|
2 |
-
import torch
|
3 |
-
from tqdm import tqdm
|
4 |
-
import torchaudio
|
5 |
-
import librosa
|
6 |
-
import os
|
7 |
-
import math
|
8 |
-
import numpy as np
|
9 |
-
from get_melvaehifigan48k import build_pretrained_models
|
10 |
-
import tools.torch_tools as torch_tools
|
11 |
-
|
12 |
-
class Tango:
|
13 |
-
def __init__(self, \
|
14 |
-
device="cuda:0"):
|
15 |
-
|
16 |
-
self.sample_rate = 48000
|
17 |
-
self.device = device
|
18 |
-
|
19 |
-
self.vae, self.stft = build_pretrained_models()
|
20 |
-
self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
|
21 |
-
|
22 |
-
def mel_spectrogram_to_waveform(self, mel_spectrogram):
|
23 |
-
if mel_spectrogram.dim() == 4:
|
24 |
-
mel_spectrogram = mel_spectrogram.squeeze(1)
|
25 |
-
|
26 |
-
waveform = self.vocoder(mel_spectrogram)
|
27 |
-
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
28 |
-
waveform = waveform.cpu().float()
|
29 |
-
return waveform
|
30 |
-
|
31 |
-
def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
|
32 |
-
""" Genrate audio without condition. """
|
33 |
-
num_frames = math.ceil(duration * 100. / 8)
|
34 |
-
with torch.no_grad():
|
35 |
-
orig_samples, fs = torchaudio.load(fname)
|
36 |
-
if(orig_samples.shape[-1]<int(duration*48000*2)):
|
37 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*48000*2+480)-orig_samples.shape[-1], \
|
38 |
-
dtype=orig_samples.dtype, device=orig_samples.device)], -1)
|
39 |
-
# orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
40 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
41 |
-
if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
|
42 |
-
# resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
|
43 |
-
resampled_audios = orig_samples[[0],0:int(duration*2*48000)+480].clamp(-1,1)
|
44 |
-
orig_samples = orig_samples[[0],0:int(duration*2*48000)]
|
45 |
-
|
46 |
-
mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
|
47 |
-
mel = mel.unsqueeze(1).to(self.device)
|
48 |
-
|
49 |
-
audio = self.vae.decode_to_waveform(mel)
|
50 |
-
audio = torch.from_numpy(audio)
|
51 |
-
|
52 |
-
if(orig_samples.shape[-1]<audio.shape[-1]):
|
53 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
|
54 |
-
else:
|
55 |
-
orig_samples = orig_samples[:,0:audio.shape[-1]]
|
56 |
-
output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
|
57 |
-
return output
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
import torchaudio
|
5 |
+
import librosa
|
6 |
+
import os
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
from get_melvaehifigan48k import build_pretrained_models
|
10 |
+
import tools.torch_tools as torch_tools
|
11 |
+
|
12 |
+
class Tango:
|
13 |
+
def __init__(self, \
|
14 |
+
device="cuda:0"):
|
15 |
+
|
16 |
+
self.sample_rate = 48000
|
17 |
+
self.device = device
|
18 |
+
|
19 |
+
self.vae, self.stft = build_pretrained_models()
|
20 |
+
self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
|
21 |
+
|
22 |
+
def mel_spectrogram_to_waveform(self, mel_spectrogram):
|
23 |
+
if mel_spectrogram.dim() == 4:
|
24 |
+
mel_spectrogram = mel_spectrogram.squeeze(1)
|
25 |
+
|
26 |
+
waveform = self.vocoder(mel_spectrogram)
|
27 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
28 |
+
waveform = waveform.cpu().float()
|
29 |
+
return waveform
|
30 |
+
|
31 |
+
def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
|
32 |
+
""" Genrate audio without condition. """
|
33 |
+
num_frames = math.ceil(duration * 100. / 8)
|
34 |
+
with torch.no_grad():
|
35 |
+
orig_samples, fs = torchaudio.load(fname)
|
36 |
+
if(orig_samples.shape[-1]<int(duration*48000*2)):
|
37 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*48000*2+480)-orig_samples.shape[-1], \
|
38 |
+
dtype=orig_samples.dtype, device=orig_samples.device)], -1)
|
39 |
+
# orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
40 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
41 |
+
if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
|
42 |
+
# resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
|
43 |
+
resampled_audios = orig_samples[[0],0:int(duration*2*48000)+480].clamp(-1,1)
|
44 |
+
orig_samples = orig_samples[[0],0:int(duration*2*48000)]
|
45 |
+
|
46 |
+
mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
|
47 |
+
mel = mel.unsqueeze(1).to(self.device)
|
48 |
+
|
49 |
+
audio = self.vae.decode_to_waveform(mel)
|
50 |
+
audio = torch.from_numpy(audio)
|
51 |
+
|
52 |
+
if(orig_samples.shape[-1]<audio.shape[-1]):
|
53 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
|
54 |
+
else:
|
55 |
+
orig_samples = orig_samples[:,0:audio.shape[-1]]
|
56 |
+
output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
|
57 |
+
return output
|
codeclm/tokenizer/Flow1dVAE/tools/infer_vaehifigan48k.py
CHANGED
@@ -1,59 +1,59 @@
|
|
1 |
-
import json
|
2 |
-
import torch
|
3 |
-
from tqdm import tqdm
|
4 |
-
import torchaudio
|
5 |
-
import librosa
|
6 |
-
import os
|
7 |
-
import math
|
8 |
-
import numpy as np
|
9 |
-
from get_melvaehifigan48k import build_pretrained_models
|
10 |
-
import tools.torch_tools as torch_tools
|
11 |
-
|
12 |
-
class Tango:
|
13 |
-
def __init__(self, \
|
14 |
-
device="cuda:0"):
|
15 |
-
|
16 |
-
self.sample_rate = 48000
|
17 |
-
self.device = device
|
18 |
-
|
19 |
-
self.vae, self.stft = build_pretrained_models()
|
20 |
-
self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
|
21 |
-
|
22 |
-
def mel_spectrogram_to_waveform(self, mel_spectrogram):
|
23 |
-
if mel_spectrogram.dim() == 4:
|
24 |
-
mel_spectrogram = mel_spectrogram.squeeze(1)
|
25 |
-
|
26 |
-
waveform = self.vocoder(mel_spectrogram)
|
27 |
-
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
28 |
-
waveform = waveform.cpu().float()
|
29 |
-
return waveform
|
30 |
-
|
31 |
-
def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
|
32 |
-
""" Genrate audio without condition. """
|
33 |
-
num_frames = math.ceil(duration * 100. / 8)
|
34 |
-
with torch.no_grad():
|
35 |
-
orig_samples, fs = torchaudio.load(fname)
|
36 |
-
if(orig_samples.shape[-1]<int(duration*48000*3)):
|
37 |
-
orig_samples = orig_samples.repeat(1,math.ceil(int(duration*48000*3)/float(orig_samples.shape[-1])))
|
38 |
-
# orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
39 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
40 |
-
if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
|
41 |
-
# resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
|
42 |
-
resampled_audios = orig_samples[[0],int(0*48000):int(duration*3*48000)+480].clamp(-1,1)
|
43 |
-
orig_samples = orig_samples[[0],:]
|
44 |
-
|
45 |
-
mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
|
46 |
-
mel = mel.unsqueeze(1).to(self.device)
|
47 |
-
latents = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0)
|
48 |
-
|
49 |
-
mel = self.vae.decode_first_stage(latents)
|
50 |
-
audio = self.vae.decode_to_waveform(mel)
|
51 |
-
audio = torch.from_numpy(audio)
|
52 |
-
|
53 |
-
orig_samples = orig_samples[...,0:int(duration * 3 * 48000)]
|
54 |
-
if(orig_samples.shape[-1]<audio.shape[-1]):
|
55 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
|
56 |
-
else:
|
57 |
-
orig_samples = orig_samples[:,0:audio.shape[-1]]
|
58 |
-
output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
|
59 |
-
return output
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
import torchaudio
|
5 |
+
import librosa
|
6 |
+
import os
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
from get_melvaehifigan48k import build_pretrained_models
|
10 |
+
import tools.torch_tools as torch_tools
|
11 |
+
|
12 |
+
class Tango:
|
13 |
+
def __init__(self, \
|
14 |
+
device="cuda:0"):
|
15 |
+
|
16 |
+
self.sample_rate = 48000
|
17 |
+
self.device = device
|
18 |
+
|
19 |
+
self.vae, self.stft = build_pretrained_models()
|
20 |
+
self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
|
21 |
+
|
22 |
+
def mel_spectrogram_to_waveform(self, mel_spectrogram):
|
23 |
+
if mel_spectrogram.dim() == 4:
|
24 |
+
mel_spectrogram = mel_spectrogram.squeeze(1)
|
25 |
+
|
26 |
+
waveform = self.vocoder(mel_spectrogram)
|
27 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
28 |
+
waveform = waveform.cpu().float()
|
29 |
+
return waveform
|
30 |
+
|
31 |
+
def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
|
32 |
+
""" Genrate audio without condition. """
|
33 |
+
num_frames = math.ceil(duration * 100. / 8)
|
34 |
+
with torch.no_grad():
|
35 |
+
orig_samples, fs = torchaudio.load(fname)
|
36 |
+
if(orig_samples.shape[-1]<int(duration*48000*3)):
|
37 |
+
orig_samples = orig_samples.repeat(1,math.ceil(int(duration*48000*3)/float(orig_samples.shape[-1])))
|
38 |
+
# orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
39 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
40 |
+
if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
|
41 |
+
# resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
|
42 |
+
resampled_audios = orig_samples[[0],int(0*48000):int(duration*3*48000)+480].clamp(-1,1)
|
43 |
+
orig_samples = orig_samples[[0],:]
|
44 |
+
|
45 |
+
mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
|
46 |
+
mel = mel.unsqueeze(1).to(self.device)
|
47 |
+
latents = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0)
|
48 |
+
|
49 |
+
mel = self.vae.decode_first_stage(latents)
|
50 |
+
audio = self.vae.decode_to_waveform(mel)
|
51 |
+
audio = torch.from_numpy(audio)
|
52 |
+
|
53 |
+
orig_samples = orig_samples[...,0:int(duration * 3 * 48000)]
|
54 |
+
if(orig_samples.shape[-1]<audio.shape[-1]):
|
55 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
|
56 |
+
else:
|
57 |
+
orig_samples = orig_samples[:,0:audio.shape[-1]]
|
58 |
+
output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
|
59 |
+
return output
|
codeclm/tokenizer/Flow1dVAE/tools/infer_vaehifigan48k_soundmusic.py
CHANGED
@@ -1,61 +1,61 @@
|
|
1 |
-
import json
|
2 |
-
import torch
|
3 |
-
from tqdm import tqdm
|
4 |
-
import torchaudio
|
5 |
-
import librosa
|
6 |
-
import os
|
7 |
-
import math
|
8 |
-
import numpy as np
|
9 |
-
from get_melvaehifigan48k import build_pretrained_models
|
10 |
-
import tools.torch_tools as torch_tools
|
11 |
-
|
12 |
-
class Tango:
|
13 |
-
def __init__(self, \
|
14 |
-
device="cuda:0"):
|
15 |
-
|
16 |
-
self.sample_rate = 48000
|
17 |
-
self.device = device
|
18 |
-
|
19 |
-
self.vae, self.stft = build_pretrained_models()
|
20 |
-
self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
|
21 |
-
|
22 |
-
# print(sum(p.numel() for p in self.vae.parameters()));exit()
|
23 |
-
|
24 |
-
def mel_spectrogram_to_waveform(self, mel_spectrogram):
|
25 |
-
if mel_spectrogram.dim() == 4:
|
26 |
-
mel_spectrogram = mel_spectrogram.squeeze(1)
|
27 |
-
|
28 |
-
waveform = self.vocoder(mel_spectrogram)
|
29 |
-
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
30 |
-
waveform = waveform.cpu().float()
|
31 |
-
return waveform
|
32 |
-
|
33 |
-
def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
|
34 |
-
""" Genrate audio without condition. """
|
35 |
-
num_frames = math.ceil(duration * 100. / 8)
|
36 |
-
with torch.no_grad():
|
37 |
-
orig_samples, fs = torchaudio.load(fname)
|
38 |
-
if(orig_samples.shape[-1]<int(duration*48000)):
|
39 |
-
orig_samples = orig_samples.repeat(1,math.ceil(int(duration*48000)/float(orig_samples.shape[-1])))
|
40 |
-
# orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
41 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
42 |
-
if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
|
43 |
-
# resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
|
44 |
-
resampled_audios = orig_samples[[0],int(0*48000):int(duration*48000)+480].clamp(-1,1)
|
45 |
-
orig_samples = orig_samples[[0],:]
|
46 |
-
|
47 |
-
mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
|
48 |
-
mel = mel.unsqueeze(1).to(self.device)
|
49 |
-
latents = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0)
|
50 |
-
|
51 |
-
mel = self.vae.decode_first_stage(latents)
|
52 |
-
audio = self.vae.decode_to_waveform(mel)
|
53 |
-
audio = torch.from_numpy(audio)
|
54 |
-
|
55 |
-
orig_samples = orig_samples[...,0:int(duration * 48000)]
|
56 |
-
if(orig_samples.shape[-1]<audio.shape[-1]):
|
57 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
|
58 |
-
else:
|
59 |
-
orig_samples = orig_samples[:,0:audio.shape[-1]]
|
60 |
-
output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
|
61 |
-
return output
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
import torchaudio
|
5 |
+
import librosa
|
6 |
+
import os
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
from get_melvaehifigan48k import build_pretrained_models
|
10 |
+
import tools.torch_tools as torch_tools
|
11 |
+
|
12 |
+
class Tango:
|
13 |
+
def __init__(self, \
|
14 |
+
device="cuda:0"):
|
15 |
+
|
16 |
+
self.sample_rate = 48000
|
17 |
+
self.device = device
|
18 |
+
|
19 |
+
self.vae, self.stft = build_pretrained_models()
|
20 |
+
self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
|
21 |
+
|
22 |
+
# print(sum(p.numel() for p in self.vae.parameters()));exit()
|
23 |
+
|
24 |
+
def mel_spectrogram_to_waveform(self, mel_spectrogram):
|
25 |
+
if mel_spectrogram.dim() == 4:
|
26 |
+
mel_spectrogram = mel_spectrogram.squeeze(1)
|
27 |
+
|
28 |
+
waveform = self.vocoder(mel_spectrogram)
|
29 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
30 |
+
waveform = waveform.cpu().float()
|
31 |
+
return waveform
|
32 |
+
|
33 |
+
def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
|
34 |
+
""" Genrate audio without condition. """
|
35 |
+
num_frames = math.ceil(duration * 100. / 8)
|
36 |
+
with torch.no_grad():
|
37 |
+
orig_samples, fs = torchaudio.load(fname)
|
38 |
+
if(orig_samples.shape[-1]<int(duration*48000)):
|
39 |
+
orig_samples = orig_samples.repeat(1,math.ceil(int(duration*48000)/float(orig_samples.shape[-1])))
|
40 |
+
# orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
41 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
42 |
+
if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
|
43 |
+
# resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
|
44 |
+
resampled_audios = orig_samples[[0],int(0*48000):int(duration*48000)+480].clamp(-1,1)
|
45 |
+
orig_samples = orig_samples[[0],:]
|
46 |
+
|
47 |
+
mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
|
48 |
+
mel = mel.unsqueeze(1).to(self.device)
|
49 |
+
latents = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0)
|
50 |
+
|
51 |
+
mel = self.vae.decode_first_stage(latents)
|
52 |
+
audio = self.vae.decode_to_waveform(mel)
|
53 |
+
audio = torch.from_numpy(audio)
|
54 |
+
|
55 |
+
orig_samples = orig_samples[...,0:int(duration * 48000)]
|
56 |
+
if(orig_samples.shape[-1]<audio.shape[-1]):
|
57 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
|
58 |
+
else:
|
59 |
+
orig_samples = orig_samples[:,0:audio.shape[-1]]
|
60 |
+
output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
|
61 |
+
return output
|
codeclm/tokenizer/Flow1dVAE/tools/infer_vaehifigan48k_speech.py
CHANGED
@@ -1,58 +1,58 @@
|
|
1 |
-
import json
|
2 |
-
import torch
|
3 |
-
from tqdm import tqdm
|
4 |
-
import torchaudio
|
5 |
-
import librosa
|
6 |
-
import os
|
7 |
-
import math
|
8 |
-
import numpy as np
|
9 |
-
from get_melvaehifigan48k import build_pretrained_models
|
10 |
-
import tools.torch_tools as torch_tools
|
11 |
-
|
12 |
-
class Tango:
|
13 |
-
def __init__(self, \
|
14 |
-
device="cuda:0"):
|
15 |
-
|
16 |
-
self.sample_rate = 48000
|
17 |
-
self.device = device
|
18 |
-
|
19 |
-
self.vae, self.stft = build_pretrained_models()
|
20 |
-
self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
|
21 |
-
|
22 |
-
def mel_spectrogram_to_waveform(self, mel_spectrogram):
|
23 |
-
if mel_spectrogram.dim() == 4:
|
24 |
-
mel_spectrogram = mel_spectrogram.squeeze(1)
|
25 |
-
|
26 |
-
waveform = self.vocoder(mel_spectrogram)
|
27 |
-
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
28 |
-
waveform = waveform.cpu().float()
|
29 |
-
return waveform
|
30 |
-
|
31 |
-
def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
|
32 |
-
""" Genrate audio without condition. """
|
33 |
-
num_frames = math.ceil(duration * 100. / 8)
|
34 |
-
with torch.no_grad():
|
35 |
-
orig_samples, fs = torchaudio.load(fname)
|
36 |
-
if(orig_samples.shape[-1]<int(duration*48000)):
|
37 |
-
orig_samples = orig_samples.repeat(1,math.ceil(int(duration*48000)/float(orig_samples.shape[-1])))
|
38 |
-
# orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
39 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
40 |
-
if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
|
41 |
-
# resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
|
42 |
-
resampled_audios = orig_samples[[0],0:int(duration*48000)+480].clamp(-1,1)
|
43 |
-
orig_samples = orig_samples[[0],0:int(duration*48000)]
|
44 |
-
|
45 |
-
mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
|
46 |
-
mel = mel.unsqueeze(1).to(self.device)
|
47 |
-
latents = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0)
|
48 |
-
|
49 |
-
mel = self.vae.decode_first_stage(latents)
|
50 |
-
audio = self.vae.decode_to_waveform(mel)
|
51 |
-
audio = torch.from_numpy(audio)
|
52 |
-
|
53 |
-
if(orig_samples.shape[-1]<audio.shape[-1]):
|
54 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
|
55 |
-
else:
|
56 |
-
orig_samples = orig_samples[:,0:audio.shape[-1]]
|
57 |
-
output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
|
58 |
-
return output
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
import torchaudio
|
5 |
+
import librosa
|
6 |
+
import os
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
from get_melvaehifigan48k import build_pretrained_models
|
10 |
+
import tools.torch_tools as torch_tools
|
11 |
+
|
12 |
+
class Tango:
|
13 |
+
def __init__(self, \
|
14 |
+
device="cuda:0"):
|
15 |
+
|
16 |
+
self.sample_rate = 48000
|
17 |
+
self.device = device
|
18 |
+
|
19 |
+
self.vae, self.stft = build_pretrained_models()
|
20 |
+
self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
|
21 |
+
|
22 |
+
def mel_spectrogram_to_waveform(self, mel_spectrogram):
|
23 |
+
if mel_spectrogram.dim() == 4:
|
24 |
+
mel_spectrogram = mel_spectrogram.squeeze(1)
|
25 |
+
|
26 |
+
waveform = self.vocoder(mel_spectrogram)
|
27 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
28 |
+
waveform = waveform.cpu().float()
|
29 |
+
return waveform
|
30 |
+
|
31 |
+
def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
|
32 |
+
""" Genrate audio without condition. """
|
33 |
+
num_frames = math.ceil(duration * 100. / 8)
|
34 |
+
with torch.no_grad():
|
35 |
+
orig_samples, fs = torchaudio.load(fname)
|
36 |
+
if(orig_samples.shape[-1]<int(duration*48000)):
|
37 |
+
orig_samples = orig_samples.repeat(1,math.ceil(int(duration*48000)/float(orig_samples.shape[-1])))
|
38 |
+
# orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
39 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
40 |
+
if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
|
41 |
+
# resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
|
42 |
+
resampled_audios = orig_samples[[0],0:int(duration*48000)+480].clamp(-1,1)
|
43 |
+
orig_samples = orig_samples[[0],0:int(duration*48000)]
|
44 |
+
|
45 |
+
mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
|
46 |
+
mel = mel.unsqueeze(1).to(self.device)
|
47 |
+
latents = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0)
|
48 |
+
|
49 |
+
mel = self.vae.decode_first_stage(latents)
|
50 |
+
audio = self.vae.decode_to_waveform(mel)
|
51 |
+
audio = torch.from_numpy(audio)
|
52 |
+
|
53 |
+
if(orig_samples.shape[-1]<audio.shape[-1]):
|
54 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
|
55 |
+
else:
|
56 |
+
orig_samples = orig_samples[:,0:audio.shape[-1]]
|
57 |
+
output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
|
58 |
+
return output
|
codeclm/tokenizer/Flow1dVAE/tools/infer_vaehifigan48k_vocal.py
CHANGED
@@ -1,59 +1,59 @@
|
|
1 |
-
import json
|
2 |
-
import torch
|
3 |
-
from tqdm import tqdm
|
4 |
-
import torchaudio
|
5 |
-
import librosa
|
6 |
-
import os
|
7 |
-
import math
|
8 |
-
import numpy as np
|
9 |
-
from get_melvaehifigan48k import build_pretrained_models
|
10 |
-
import tools.torch_tools as torch_tools
|
11 |
-
|
12 |
-
class Tango:
|
13 |
-
def __init__(self, \
|
14 |
-
device="cuda:0"):
|
15 |
-
|
16 |
-
self.sample_rate = 48000
|
17 |
-
self.device = device
|
18 |
-
|
19 |
-
self.vae, self.stft = build_pretrained_models()
|
20 |
-
self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
|
21 |
-
|
22 |
-
def mel_spectrogram_to_waveform(self, mel_spectrogram):
|
23 |
-
if mel_spectrogram.dim() == 4:
|
24 |
-
mel_spectrogram = mel_spectrogram.squeeze(1)
|
25 |
-
|
26 |
-
waveform = self.vocoder(mel_spectrogram)
|
27 |
-
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
28 |
-
waveform = waveform.cpu().float()
|
29 |
-
return waveform
|
30 |
-
|
31 |
-
def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
|
32 |
-
""" Genrate audio without condition. """
|
33 |
-
num_frames = math.ceil(duration * 100. / 8)
|
34 |
-
with torch.no_grad():
|
35 |
-
orig_samples, fs = torchaudio.load(fname)
|
36 |
-
if(orig_samples.shape[-1]<int(duration*48000*2)):
|
37 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*48000*2+480)-orig_samples.shape[-1], \
|
38 |
-
dtype=orig_samples.dtype, device=orig_samples.device)], -1)
|
39 |
-
# orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
40 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
41 |
-
if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
|
42 |
-
# resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
|
43 |
-
resampled_audios = orig_samples[[0],0:int(duration*2*48000)+480].clamp(-1,1)
|
44 |
-
orig_samples = orig_samples[[0],0:int(duration*2*48000)]
|
45 |
-
|
46 |
-
mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
|
47 |
-
mel = mel.unsqueeze(1).to(self.device)
|
48 |
-
latents = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0)
|
49 |
-
|
50 |
-
mel = self.vae.decode_first_stage(latents)
|
51 |
-
audio = self.vae.decode_to_waveform(mel)
|
52 |
-
audio = torch.from_numpy(audio)
|
53 |
-
|
54 |
-
if(orig_samples.shape[-1]<audio.shape[-1]):
|
55 |
-
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
|
56 |
-
else:
|
57 |
-
orig_samples = orig_samples[:,0:audio.shape[-1]]
|
58 |
-
output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
|
59 |
-
return output
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
import torchaudio
|
5 |
+
import librosa
|
6 |
+
import os
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
from get_melvaehifigan48k import build_pretrained_models
|
10 |
+
import tools.torch_tools as torch_tools
|
11 |
+
|
12 |
+
class Tango:
|
13 |
+
def __init__(self, \
|
14 |
+
device="cuda:0"):
|
15 |
+
|
16 |
+
self.sample_rate = 48000
|
17 |
+
self.device = device
|
18 |
+
|
19 |
+
self.vae, self.stft = build_pretrained_models()
|
20 |
+
self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
|
21 |
+
|
22 |
+
def mel_spectrogram_to_waveform(self, mel_spectrogram):
|
23 |
+
if mel_spectrogram.dim() == 4:
|
24 |
+
mel_spectrogram = mel_spectrogram.squeeze(1)
|
25 |
+
|
26 |
+
waveform = self.vocoder(mel_spectrogram)
|
27 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
28 |
+
waveform = waveform.cpu().float()
|
29 |
+
return waveform
|
30 |
+
|
31 |
+
def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
|
32 |
+
""" Genrate audio without condition. """
|
33 |
+
num_frames = math.ceil(duration * 100. / 8)
|
34 |
+
with torch.no_grad():
|
35 |
+
orig_samples, fs = torchaudio.load(fname)
|
36 |
+
if(orig_samples.shape[-1]<int(duration*48000*2)):
|
37 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*48000*2+480)-orig_samples.shape[-1], \
|
38 |
+
dtype=orig_samples.dtype, device=orig_samples.device)], -1)
|
39 |
+
# orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
40 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
|
41 |
+
if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
|
42 |
+
# resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
|
43 |
+
resampled_audios = orig_samples[[0],0:int(duration*2*48000)+480].clamp(-1,1)
|
44 |
+
orig_samples = orig_samples[[0],0:int(duration*2*48000)]
|
45 |
+
|
46 |
+
mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
|
47 |
+
mel = mel.unsqueeze(1).to(self.device)
|
48 |
+
latents = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0)
|
49 |
+
|
50 |
+
mel = self.vae.decode_first_stage(latents)
|
51 |
+
audio = self.vae.decode_to_waveform(mel)
|
52 |
+
audio = torch.from_numpy(audio)
|
53 |
+
|
54 |
+
if(orig_samples.shape[-1]<audio.shape[-1]):
|
55 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
|
56 |
+
else:
|
57 |
+
orig_samples = orig_samples[:,0:audio.shape[-1]]
|
58 |
+
output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
|
59 |
+
return output
|
codeclm/tokenizer/Flow1dVAE/tools/mix.py
CHANGED
@@ -1,51 +1,51 @@
|
|
1 |
-
import numpy as np
|
2 |
-
|
3 |
-
|
4 |
-
def a_weight(fs, n_fft, min_db=-80.0):
|
5 |
-
freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
|
6 |
-
freq_sq = np.power(freq, 2)
|
7 |
-
freq_sq[0] = 1.0
|
8 |
-
weight = 2.0 + 20.0 * (2 * np.log10(12194) + 2 * np.log10(freq_sq)
|
9 |
-
- np.log10(freq_sq + 12194 ** 2)
|
10 |
-
- np.log10(freq_sq + 20.6 ** 2)
|
11 |
-
- 0.5 * np.log10(freq_sq + 107.7 ** 2)
|
12 |
-
- 0.5 * np.log10(freq_sq + 737.9 ** 2))
|
13 |
-
weight = np.maximum(weight, min_db)
|
14 |
-
|
15 |
-
return weight
|
16 |
-
|
17 |
-
|
18 |
-
def compute_gain(sound, fs, min_db=-80.0, mode="A_weighting"):
|
19 |
-
if fs == 16000:
|
20 |
-
n_fft = 2048
|
21 |
-
elif fs == 44100:
|
22 |
-
n_fft = 4096
|
23 |
-
else:
|
24 |
-
raise Exception("Invalid fs {}".format(fs))
|
25 |
-
stride = n_fft // 2
|
26 |
-
|
27 |
-
gain = []
|
28 |
-
for i in range(0, len(sound) - n_fft + 1, stride):
|
29 |
-
if mode == "RMSE":
|
30 |
-
g = np.mean(sound[i: i + n_fft] ** 2)
|
31 |
-
elif mode == "A_weighting":
|
32 |
-
spec = np.fft.rfft(np.hanning(n_fft + 1)[:-1] * sound[i: i + n_fft])
|
33 |
-
power_spec = np.abs(spec) ** 2
|
34 |
-
a_weighted_spec = power_spec * np.power(10, a_weight(fs, n_fft) / 10)
|
35 |
-
g = np.sum(a_weighted_spec)
|
36 |
-
else:
|
37 |
-
raise Exception("Invalid mode {}".format(mode))
|
38 |
-
gain.append(g)
|
39 |
-
|
40 |
-
gain = np.array(gain)
|
41 |
-
gain = np.maximum(gain, np.power(10, min_db / 10))
|
42 |
-
gain_db = 10 * np.log10(gain)
|
43 |
-
return gain_db
|
44 |
-
|
45 |
-
|
46 |
-
def mix(sound1, sound2, r, fs):
|
47 |
-
gain1 = np.max(compute_gain(sound1, fs)) # Decibel
|
48 |
-
gain2 = np.max(compute_gain(sound2, fs))
|
49 |
-
t = 1.0 / (1 + np.power(10, (gain1 - gain2) / 20.) * (1 - r) / r)
|
50 |
-
sound = ((sound1 * t + sound2 * (1 - t)) / np.sqrt(t ** 2 + (1 - t) ** 2))
|
51 |
return sound
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def a_weight(fs, n_fft, min_db=-80.0):
|
5 |
+
freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
|
6 |
+
freq_sq = np.power(freq, 2)
|
7 |
+
freq_sq[0] = 1.0
|
8 |
+
weight = 2.0 + 20.0 * (2 * np.log10(12194) + 2 * np.log10(freq_sq)
|
9 |
+
- np.log10(freq_sq + 12194 ** 2)
|
10 |
+
- np.log10(freq_sq + 20.6 ** 2)
|
11 |
+
- 0.5 * np.log10(freq_sq + 107.7 ** 2)
|
12 |
+
- 0.5 * np.log10(freq_sq + 737.9 ** 2))
|
13 |
+
weight = np.maximum(weight, min_db)
|
14 |
+
|
15 |
+
return weight
|
16 |
+
|
17 |
+
|
18 |
+
def compute_gain(sound, fs, min_db=-80.0, mode="A_weighting"):
|
19 |
+
if fs == 16000:
|
20 |
+
n_fft = 2048
|
21 |
+
elif fs == 44100:
|
22 |
+
n_fft = 4096
|
23 |
+
else:
|
24 |
+
raise Exception("Invalid fs {}".format(fs))
|
25 |
+
stride = n_fft // 2
|
26 |
+
|
27 |
+
gain = []
|
28 |
+
for i in range(0, len(sound) - n_fft + 1, stride):
|
29 |
+
if mode == "RMSE":
|
30 |
+
g = np.mean(sound[i: i + n_fft] ** 2)
|
31 |
+
elif mode == "A_weighting":
|
32 |
+
spec = np.fft.rfft(np.hanning(n_fft + 1)[:-1] * sound[i: i + n_fft])
|
33 |
+
power_spec = np.abs(spec) ** 2
|
34 |
+
a_weighted_spec = power_spec * np.power(10, a_weight(fs, n_fft) / 10)
|
35 |
+
g = np.sum(a_weighted_spec)
|
36 |
+
else:
|
37 |
+
raise Exception("Invalid mode {}".format(mode))
|
38 |
+
gain.append(g)
|
39 |
+
|
40 |
+
gain = np.array(gain)
|
41 |
+
gain = np.maximum(gain, np.power(10, min_db / 10))
|
42 |
+
gain_db = 10 * np.log10(gain)
|
43 |
+
return gain_db
|
44 |
+
|
45 |
+
|
46 |
+
def mix(sound1, sound2, r, fs):
|
47 |
+
gain1 = np.max(compute_gain(sound1, fs)) # Decibel
|
48 |
+
gain2 = np.max(compute_gain(sound2, fs))
|
49 |
+
t = 1.0 / (1 + np.power(10, (gain1 - gain2) / 20.) * (1 - r) / r)
|
50 |
+
sound = ((sound1 * t + sound2 * (1 - t)) / np.sqrt(t ** 2 + (1 - t) ** 2))
|
51 |
return sound
|
codeclm/tokenizer/Flow1dVAE/tools/torch_tools.py
CHANGED
@@ -1,143 +1,143 @@
|
|
1 |
-
import torch
|
2 |
-
import torchaudio
|
3 |
-
import random
|
4 |
-
import itertools
|
5 |
-
import numpy as np
|
6 |
-
from tools.mix import mix
|
7 |
-
|
8 |
-
|
9 |
-
def normalize_wav(waveform):
|
10 |
-
waveform = waveform - torch.mean(waveform)
|
11 |
-
waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
|
12 |
-
return waveform * 0.5
|
13 |
-
|
14 |
-
|
15 |
-
def pad_wav(waveform, segment_length):
|
16 |
-
waveform_length = len(waveform)
|
17 |
-
|
18 |
-
if segment_length is None or waveform_length == segment_length:
|
19 |
-
return waveform
|
20 |
-
elif waveform_length > segment_length:
|
21 |
-
return waveform[:segment_length]
|
22 |
-
else:
|
23 |
-
pad_wav = torch.zeros(segment_length - waveform_length).to(waveform.device)
|
24 |
-
waveform = torch.cat([waveform, pad_wav])
|
25 |
-
return waveform
|
26 |
-
|
27 |
-
|
28 |
-
def _pad_spec(fbank, target_length=1024):
|
29 |
-
batch, n_frames, channels = fbank.shape
|
30 |
-
p = target_length - n_frames
|
31 |
-
if p > 0:
|
32 |
-
pad = torch.zeros(batch, p, channels).to(fbank.device)
|
33 |
-
fbank = torch.cat([fbank, pad], 1)
|
34 |
-
elif p < 0:
|
35 |
-
fbank = fbank[:, :target_length, :]
|
36 |
-
|
37 |
-
if channels % 2 != 0:
|
38 |
-
fbank = fbank[:, :, :-1]
|
39 |
-
|
40 |
-
return fbank
|
41 |
-
|
42 |
-
|
43 |
-
def read_wav_file(filename, segment_length):
|
44 |
-
waveform, sr = torchaudio.load(filename) # Faster!!!
|
45 |
-
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)[0]
|
46 |
-
try:
|
47 |
-
waveform = normalize_wav(waveform)
|
48 |
-
except:
|
49 |
-
print ("Exception normalizing:", filename)
|
50 |
-
waveform = torch.ones(160000)
|
51 |
-
waveform = pad_wav(waveform, segment_length).unsqueeze(0)
|
52 |
-
waveform = waveform / torch.max(torch.abs(waveform))
|
53 |
-
waveform = 0.5 * waveform
|
54 |
-
return waveform
|
55 |
-
|
56 |
-
|
57 |
-
def get_mel_from_wav(audio, _stft):
|
58 |
-
audio = torch.nan_to_num(torch.clip(audio, -1, 1))
|
59 |
-
audio = torch.autograd.Variable(audio, requires_grad=False)
|
60 |
-
melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
|
61 |
-
return melspec, log_magnitudes_stft, energy
|
62 |
-
|
63 |
-
|
64 |
-
def wav_to_fbank(paths, target_length=1024, fn_STFT=None):
|
65 |
-
assert fn_STFT is not None
|
66 |
-
|
67 |
-
waveform = torch.cat([read_wav_file(path, target_length * 160) for path in paths], 0) # hop size is 160
|
68 |
-
|
69 |
-
fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
|
70 |
-
fbank = fbank.transpose(1, 2)
|
71 |
-
log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
|
72 |
-
|
73 |
-
fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
|
74 |
-
log_magnitudes_stft, target_length
|
75 |
-
)
|
76 |
-
|
77 |
-
return fbank, log_magnitudes_stft, waveform
|
78 |
-
|
79 |
-
def wav_to_fbank2(waveform, target_length=-1, fn_STFT=None):
|
80 |
-
assert fn_STFT is not None
|
81 |
-
|
82 |
-
fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
|
83 |
-
fbank = fbank.transpose(1, 2)
|
84 |
-
log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
|
85 |
-
# print(fbank.shape, log_magnitudes_stft.shape)
|
86 |
-
|
87 |
-
if(target_length>0):
|
88 |
-
fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
|
89 |
-
log_magnitudes_stft, target_length
|
90 |
-
)
|
91 |
-
|
92 |
-
return fbank, log_magnitudes_stft, waveform
|
93 |
-
|
94 |
-
|
95 |
-
def uncapitalize(s):
|
96 |
-
if s:
|
97 |
-
return s[:1].lower() + s[1:]
|
98 |
-
else:
|
99 |
-
return ""
|
100 |
-
|
101 |
-
|
102 |
-
def mix_wavs_and_captions(path1, path2, caption1, caption2, target_length=1024):
|
103 |
-
sound1 = read_wav_file(path1, target_length * 160)[0].numpy()
|
104 |
-
sound2 = read_wav_file(path2, target_length * 160)[0].numpy()
|
105 |
-
mixed_sound = mix(sound1, sound2, 0.5, 16000).reshape(1, -1)
|
106 |
-
mixed_caption = "{} and {}".format(caption1, uncapitalize(caption2))
|
107 |
-
return mixed_sound, mixed_caption
|
108 |
-
|
109 |
-
|
110 |
-
def augment(paths, texts, num_items=4, target_length=1024):
|
111 |
-
mixed_sounds, mixed_captions = [], []
|
112 |
-
combinations = list(itertools.combinations(list(range(len(texts))), 2))
|
113 |
-
random.shuffle(combinations)
|
114 |
-
if len(combinations) < num_items:
|
115 |
-
selected_combinations = combinations
|
116 |
-
else:
|
117 |
-
selected_combinations = combinations[:num_items]
|
118 |
-
|
119 |
-
for (i, j) in selected_combinations:
|
120 |
-
new_sound, new_caption = mix_wavs_and_captions(paths[i], paths[j], texts[i], texts[j], target_length)
|
121 |
-
mixed_sounds.append(new_sound)
|
122 |
-
mixed_captions.append(new_caption)
|
123 |
-
|
124 |
-
waveform = torch.tensor(np.concatenate(mixed_sounds, 0))
|
125 |
-
waveform = waveform / torch.max(torch.abs(waveform))
|
126 |
-
waveform = 0.5 * waveform
|
127 |
-
|
128 |
-
return waveform, mixed_captions
|
129 |
-
|
130 |
-
|
131 |
-
def augment_wav_to_fbank(paths, texts, num_items=4, target_length=1024, fn_STFT=None):
|
132 |
-
assert fn_STFT is not None
|
133 |
-
|
134 |
-
waveform, captions = augment(paths, texts)
|
135 |
-
fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
|
136 |
-
fbank = fbank.transpose(1, 2)
|
137 |
-
log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
|
138 |
-
|
139 |
-
fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
|
140 |
-
log_magnitudes_stft, target_length
|
141 |
-
)
|
142 |
-
|
143 |
return fbank, log_magnitudes_stft, waveform, captions
|
|
|
1 |
+
import torch
|
2 |
+
import torchaudio
|
3 |
+
import random
|
4 |
+
import itertools
|
5 |
+
import numpy as np
|
6 |
+
from tools.mix import mix
|
7 |
+
|
8 |
+
|
9 |
+
def normalize_wav(waveform):
|
10 |
+
waveform = waveform - torch.mean(waveform)
|
11 |
+
waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
|
12 |
+
return waveform * 0.5
|
13 |
+
|
14 |
+
|
15 |
+
def pad_wav(waveform, segment_length):
|
16 |
+
waveform_length = len(waveform)
|
17 |
+
|
18 |
+
if segment_length is None or waveform_length == segment_length:
|
19 |
+
return waveform
|
20 |
+
elif waveform_length > segment_length:
|
21 |
+
return waveform[:segment_length]
|
22 |
+
else:
|
23 |
+
pad_wav = torch.zeros(segment_length - waveform_length).to(waveform.device)
|
24 |
+
waveform = torch.cat([waveform, pad_wav])
|
25 |
+
return waveform
|
26 |
+
|
27 |
+
|
28 |
+
def _pad_spec(fbank, target_length=1024):
|
29 |
+
batch, n_frames, channels = fbank.shape
|
30 |
+
p = target_length - n_frames
|
31 |
+
if p > 0:
|
32 |
+
pad = torch.zeros(batch, p, channels).to(fbank.device)
|
33 |
+
fbank = torch.cat([fbank, pad], 1)
|
34 |
+
elif p < 0:
|
35 |
+
fbank = fbank[:, :target_length, :]
|
36 |
+
|
37 |
+
if channels % 2 != 0:
|
38 |
+
fbank = fbank[:, :, :-1]
|
39 |
+
|
40 |
+
return fbank
|
41 |
+
|
42 |
+
|
43 |
+
def read_wav_file(filename, segment_length):
|
44 |
+
waveform, sr = torchaudio.load(filename) # Faster!!!
|
45 |
+
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)[0]
|
46 |
+
try:
|
47 |
+
waveform = normalize_wav(waveform)
|
48 |
+
except:
|
49 |
+
print ("Exception normalizing:", filename)
|
50 |
+
waveform = torch.ones(160000)
|
51 |
+
waveform = pad_wav(waveform, segment_length).unsqueeze(0)
|
52 |
+
waveform = waveform / torch.max(torch.abs(waveform))
|
53 |
+
waveform = 0.5 * waveform
|
54 |
+
return waveform
|
55 |
+
|
56 |
+
|
57 |
+
def get_mel_from_wav(audio, _stft):
|
58 |
+
audio = torch.nan_to_num(torch.clip(audio, -1, 1))
|
59 |
+
audio = torch.autograd.Variable(audio, requires_grad=False)
|
60 |
+
melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
|
61 |
+
return melspec, log_magnitudes_stft, energy
|
62 |
+
|
63 |
+
|
64 |
+
def wav_to_fbank(paths, target_length=1024, fn_STFT=None):
|
65 |
+
assert fn_STFT is not None
|
66 |
+
|
67 |
+
waveform = torch.cat([read_wav_file(path, target_length * 160) for path in paths], 0) # hop size is 160
|
68 |
+
|
69 |
+
fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
|
70 |
+
fbank = fbank.transpose(1, 2)
|
71 |
+
log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
|
72 |
+
|
73 |
+
fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
|
74 |
+
log_magnitudes_stft, target_length
|
75 |
+
)
|
76 |
+
|
77 |
+
return fbank, log_magnitudes_stft, waveform
|
78 |
+
|
79 |
+
def wav_to_fbank2(waveform, target_length=-1, fn_STFT=None):
|
80 |
+
assert fn_STFT is not None
|
81 |
+
|
82 |
+
fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
|
83 |
+
fbank = fbank.transpose(1, 2)
|
84 |
+
log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
|
85 |
+
# print(fbank.shape, log_magnitudes_stft.shape)
|
86 |
+
|
87 |
+
if(target_length>0):
|
88 |
+
fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
|
89 |
+
log_magnitudes_stft, target_length
|
90 |
+
)
|
91 |
+
|
92 |
+
return fbank, log_magnitudes_stft, waveform
|
93 |
+
|
94 |
+
|
95 |
+
def uncapitalize(s):
|
96 |
+
if s:
|
97 |
+
return s[:1].lower() + s[1:]
|
98 |
+
else:
|
99 |
+
return ""
|
100 |
+
|
101 |
+
|
102 |
+
def mix_wavs_and_captions(path1, path2, caption1, caption2, target_length=1024):
|
103 |
+
sound1 = read_wav_file(path1, target_length * 160)[0].numpy()
|
104 |
+
sound2 = read_wav_file(path2, target_length * 160)[0].numpy()
|
105 |
+
mixed_sound = mix(sound1, sound2, 0.5, 16000).reshape(1, -1)
|
106 |
+
mixed_caption = "{} and {}".format(caption1, uncapitalize(caption2))
|
107 |
+
return mixed_sound, mixed_caption
|
108 |
+
|
109 |
+
|
110 |
+
def augment(paths, texts, num_items=4, target_length=1024):
|
111 |
+
mixed_sounds, mixed_captions = [], []
|
112 |
+
combinations = list(itertools.combinations(list(range(len(texts))), 2))
|
113 |
+
random.shuffle(combinations)
|
114 |
+
if len(combinations) < num_items:
|
115 |
+
selected_combinations = combinations
|
116 |
+
else:
|
117 |
+
selected_combinations = combinations[:num_items]
|
118 |
+
|
119 |
+
for (i, j) in selected_combinations:
|
120 |
+
new_sound, new_caption = mix_wavs_and_captions(paths[i], paths[j], texts[i], texts[j], target_length)
|
121 |
+
mixed_sounds.append(new_sound)
|
122 |
+
mixed_captions.append(new_caption)
|
123 |
+
|
124 |
+
waveform = torch.tensor(np.concatenate(mixed_sounds, 0))
|
125 |
+
waveform = waveform / torch.max(torch.abs(waveform))
|
126 |
+
waveform = 0.5 * waveform
|
127 |
+
|
128 |
+
return waveform, mixed_captions
|
129 |
+
|
130 |
+
|
131 |
+
def augment_wav_to_fbank(paths, texts, num_items=4, target_length=1024, fn_STFT=None):
|
132 |
+
assert fn_STFT is not None
|
133 |
+
|
134 |
+
waveform, captions = augment(paths, texts)
|
135 |
+
fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
|
136 |
+
fbank = fbank.transpose(1, 2)
|
137 |
+
log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
|
138 |
+
|
139 |
+
fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
|
140 |
+
log_magnitudes_stft, target_length
|
141 |
+
)
|
142 |
+
|
143 |
return fbank, log_magnitudes_stft, waveform, captions
|
codeclm/tokenizer/audio_tokenizer.py
CHANGED
@@ -208,9 +208,9 @@ class Flow1dVAESeparate(AudioTokenizer):
|
|
208 |
return codes_vocal, codes_bgm
|
209 |
|
210 |
@torch.no_grad()
|
211 |
-
def decode(self, codes: torch.Tensor, prompt_vocal = None, prompt_bgm = None):
|
212 |
wav = self.model.code2sound(codes, prompt_vocal=prompt_vocal, prompt_bgm=prompt_bgm, guidance_scale=1.5,
|
213 |
-
num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
|
214 |
return wav[None]
|
215 |
|
216 |
|
|
|
208 |
return codes_vocal, codes_bgm
|
209 |
|
210 |
@torch.no_grad()
|
211 |
+
def decode(self, codes: torch.Tensor, prompt_vocal = None, prompt_bgm = None, chunked=False):
|
212 |
wav = self.model.code2sound(codes, prompt_vocal=prompt_vocal, prompt_bgm=prompt_bgm, guidance_scale=1.5,
|
213 |
+
num_steps=50, disable_progress=False, chunked=chunked) # [B,N,T] -> [B,T]
|
214 |
return wav[None]
|
215 |
|
216 |
|
generate_lowmem.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
|
4 |
+
import time
|
5 |
+
import json
|
6 |
+
import torch
|
7 |
+
import torchaudio
|
8 |
+
import numpy as np
|
9 |
+
from omegaconf import OmegaConf
|
10 |
+
from codeclm.models import builders
|
11 |
+
|
12 |
+
from codeclm.trainer.codec_song_pl import CodecLM_PL
|
13 |
+
from codeclm.models import CodecLM
|
14 |
+
from third_party.demucs.models.pretrained import get_model_from_yaml
|
15 |
+
|
16 |
+
auto_prompt_type = ['Pop', 'R&B', 'Dance', 'Jazz', 'Folk', 'Rock', 'Chinese Style', 'Chinese Tradition', 'Metal', 'Reggae', 'Chinese Opera', 'Auto']
|
17 |
+
|
18 |
+
class Separator:
|
19 |
+
def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
|
20 |
+
if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
|
21 |
+
self.device = torch.device(f"cuda:{gpu_id}")
|
22 |
+
else:
|
23 |
+
self.device = torch.device("cpu")
|
24 |
+
self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
|
25 |
+
|
26 |
+
def init_demucs_model(self, model_path, config_path):
|
27 |
+
model = get_model_from_yaml(config_path, model_path)
|
28 |
+
model.to(self.device)
|
29 |
+
model.eval()
|
30 |
+
return model
|
31 |
+
|
32 |
+
def load_audio(self, f):
|
33 |
+
a, fs = torchaudio.load(f)
|
34 |
+
if (fs != 48000):
|
35 |
+
a = torchaudio.functional.resample(a, fs, 48000)
|
36 |
+
if a.shape[-1] >= 48000*10:
|
37 |
+
a = a[..., :48000*10]
|
38 |
+
else:
|
39 |
+
a = torch.cat([a, a], -1)
|
40 |
+
return a[:, 0:48000*10]
|
41 |
+
|
42 |
+
def run(self, audio_path, output_dir='tmp', ext=".flac"):
|
43 |
+
os.makedirs(output_dir, exist_ok=True)
|
44 |
+
name, _ = os.path.splitext(os.path.split(audio_path)[-1])
|
45 |
+
output_paths = []
|
46 |
+
|
47 |
+
for stem in self.demucs_model.sources:
|
48 |
+
output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
|
49 |
+
if os.path.exists(output_path):
|
50 |
+
output_paths.append(output_path)
|
51 |
+
if len(output_paths) == 1: # 4
|
52 |
+
vocal_path = output_paths[0]
|
53 |
+
else:
|
54 |
+
drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
|
55 |
+
for path in [drums_path, bass_path, other_path]:
|
56 |
+
os.remove(path)
|
57 |
+
full_audio = self.load_audio(audio_path)
|
58 |
+
vocal_audio = self.load_audio(vocal_path)
|
59 |
+
bgm_audio = full_audio - vocal_audio
|
60 |
+
return full_audio, vocal_audio, bgm_audio
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
if __name__ == "__main__":
|
65 |
+
torch.backends.cudnn.enabled = False
|
66 |
+
OmegaConf.register_new_resolver("eval", lambda x: eval(x))
|
67 |
+
OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
|
68 |
+
OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
|
69 |
+
OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
|
70 |
+
np.random.seed(int(time.time()))
|
71 |
+
ckpt_path = sys.argv[1]
|
72 |
+
input_jsonl = sys.argv[2]
|
73 |
+
save_dir = sys.argv[3]
|
74 |
+
cfg_path = os.path.join(ckpt_path, 'config.yaml')
|
75 |
+
ckpt_path = os.path.join(ckpt_path, 'model.pt')
|
76 |
+
cfg = OmegaConf.load(cfg_path)
|
77 |
+
cfg.mode = 'inference'
|
78 |
+
max_duration = cfg.max_dur
|
79 |
+
|
80 |
+
separator = Separator()
|
81 |
+
auto_prompt = torch.load('ckpt/prompt.pt')
|
82 |
+
audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
|
83 |
+
if "audio_tokenizer_checkpoint_sep" in cfg.keys():
|
84 |
+
seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
|
85 |
+
else:
|
86 |
+
seperate_tokenizer = None
|
87 |
+
audio_tokenizer = audio_tokenizer.eval().cuda()
|
88 |
+
if seperate_tokenizer is not None:
|
89 |
+
seperate_tokenizer = seperate_tokenizer.eval().cuda()
|
90 |
+
|
91 |
+
merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
|
92 |
+
with open(input_jsonl, "r") as fp:
|
93 |
+
lines = fp.readlines()
|
94 |
+
new_items = []
|
95 |
+
for line in lines:
|
96 |
+
item = json.loads(line)
|
97 |
+
target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
|
98 |
+
# get prompt audio
|
99 |
+
if "prompt_audio_path" in item:
|
100 |
+
assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
|
101 |
+
assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
|
102 |
+
pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
|
103 |
+
item['raw_pmt_wav'] = pmt_wav
|
104 |
+
item['raw_vocal_wav'] = vocal_wav
|
105 |
+
item['raw_bgm_wav'] = bgm_wav
|
106 |
+
if pmt_wav.dim() == 2:
|
107 |
+
pmt_wav = pmt_wav[None]
|
108 |
+
if pmt_wav.dim() != 3:
|
109 |
+
raise ValueError("Melody wavs should have a shape [B, C, T].")
|
110 |
+
pmt_wav = list(pmt_wav)
|
111 |
+
if vocal_wav.dim() == 2:
|
112 |
+
vocal_wav = vocal_wav[None]
|
113 |
+
if vocal_wav.dim() != 3:
|
114 |
+
raise ValueError("Vocal wavs should have a shape [B, C, T].")
|
115 |
+
vocal_wav = list(vocal_wav)
|
116 |
+
if bgm_wav.dim() == 2:
|
117 |
+
bgm_wav = bgm_wav[None]
|
118 |
+
if bgm_wav.dim() != 3:
|
119 |
+
raise ValueError("BGM wavs should have a shape [B, C, T].")
|
120 |
+
bgm_wav = list(bgm_wav)
|
121 |
+
if type(pmt_wav) == list:
|
122 |
+
pmt_wav = torch.stack(pmt_wav, dim=0)
|
123 |
+
if type(vocal_wav) == list:
|
124 |
+
vocal_wav = torch.stack(vocal_wav, dim=0)
|
125 |
+
if type(bgm_wav) == list:
|
126 |
+
bgm_wav = torch.stack(bgm_wav, dim=0)
|
127 |
+
pmt_wav = pmt_wav.cuda()
|
128 |
+
vocal_wav = vocal_wav.cuda()
|
129 |
+
bgm_wav = bgm_wav.cuda()
|
130 |
+
pmt_wav, _ = audio_tokenizer.encode(pmt_wav)
|
131 |
+
vocal_wav, bgm_wav = seperate_tokenizer.encode(vocal_wav, bgm_wav)
|
132 |
+
melody_is_wav = False
|
133 |
+
elif "auto_prompt_audio_type" in item:
|
134 |
+
assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
|
135 |
+
if item["auto_prompt_audio_type"] == "Auto":
|
136 |
+
prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
|
137 |
+
else:
|
138 |
+
prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
|
139 |
+
pmt_wav = prompt_token[:,[0],:]
|
140 |
+
vocal_wav = prompt_token[:,[1],:]
|
141 |
+
bgm_wav = prompt_token[:,[2],:]
|
142 |
+
melody_is_wav = False
|
143 |
+
else:
|
144 |
+
pmt_wav = None
|
145 |
+
vocal_wav = None
|
146 |
+
bgm_wav = None
|
147 |
+
melody_is_wav = True
|
148 |
+
item['pmt_wav'] = pmt_wav
|
149 |
+
item['vocal_wav'] = vocal_wav
|
150 |
+
item['bgm_wav'] = bgm_wav
|
151 |
+
item['melody_is_wav'] = melody_is_wav
|
152 |
+
item["idx"] = f"{item['idx']}"
|
153 |
+
item["wav_path"] = target_wav_name
|
154 |
+
new_items.append(item)
|
155 |
+
|
156 |
+
del audio_tokenizer
|
157 |
+
del seperate_tokenizer
|
158 |
+
del separator
|
159 |
+
|
160 |
+
# Define model or load pretrained model
|
161 |
+
model_light = CodecLM_PL(cfg, ckpt_path)
|
162 |
+
model_light = model_light.eval()
|
163 |
+
model_light.audiolm.cfg = cfg
|
164 |
+
model = CodecLM(name = "tmp",
|
165 |
+
lm = model_light.audiolm,
|
166 |
+
audiotokenizer = None,
|
167 |
+
max_duration = max_duration,
|
168 |
+
seperate_tokenizer = None,
|
169 |
+
)
|
170 |
+
del model_light
|
171 |
+
model.lm = model.lm.cuda().to(torch.float16)
|
172 |
+
|
173 |
+
cfg_coef = 1.5 #25
|
174 |
+
temp = 0.9
|
175 |
+
top_k = 50
|
176 |
+
top_p = 0.0
|
177 |
+
record_tokens = True
|
178 |
+
record_window = 50
|
179 |
+
|
180 |
+
model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
|
181 |
+
top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
|
182 |
+
os.makedirs(save_dir, exist_ok=True)
|
183 |
+
os.makedirs(save_dir + "/audios", exist_ok=True)
|
184 |
+
os.makedirs(save_dir + "/jsonl", exist_ok=True)
|
185 |
+
|
186 |
+
|
187 |
+
for item in new_items:
|
188 |
+
lyric = item["gt_lyric"]
|
189 |
+
descriptions = item["descriptions"] if "descriptions" in item else None
|
190 |
+
pmt_wav = item['pmt_wav']
|
191 |
+
vocal_wav = item['vocal_wav']
|
192 |
+
bgm_wav = item['bgm_wav']
|
193 |
+
melody_is_wav = item['melody_is_wav']
|
194 |
+
|
195 |
+
generate_inp = {
|
196 |
+
'lyrics': [lyric.replace(" ", " ")],
|
197 |
+
'descriptions': [descriptions],
|
198 |
+
'melody_wavs': pmt_wav,
|
199 |
+
'vocal_wavs': vocal_wav,
|
200 |
+
'bgm_wavs': bgm_wav,
|
201 |
+
'melody_is_wav': melody_is_wav,
|
202 |
+
}
|
203 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
204 |
+
tokens = model.generate(**generate_inp, return_tokens=True)
|
205 |
+
item['tokens'] = tokens
|
206 |
+
|
207 |
+
del model
|
208 |
+
torch.cuda.empty_cache()
|
209 |
+
|
210 |
+
|
211 |
+
seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
|
212 |
+
seperate_tokenizer = seperate_tokenizer.eval().cuda()
|
213 |
+
|
214 |
+
model = CodecLM(name = "tmp",
|
215 |
+
lm = None,
|
216 |
+
audiotokenizer = None,
|
217 |
+
max_duration = max_duration,
|
218 |
+
seperate_tokenizer = seperate_tokenizer,
|
219 |
+
)
|
220 |
+
for item in new_items:
|
221 |
+
with torch.no_grad():
|
222 |
+
if 'raw_pmt_wav' in item:
|
223 |
+
wav_seperate = model.generate_audio(item['tokens'], item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True)
|
224 |
+
del item['raw_pmt_wav']
|
225 |
+
del item['raw_vocal_wav']
|
226 |
+
del item['raw_bgm_wav']
|
227 |
+
else:
|
228 |
+
wav_seperate = model.generate_audio(item['tokens'], chunked=True)
|
229 |
+
torchaudio.save(item['wav_path'], wav_seperate[0].cpu().float(), cfg.sample_rate)
|
230 |
+
del item['tokens']
|
231 |
+
del item['pmt_wav']
|
232 |
+
del item['vocal_wav']
|
233 |
+
del item['bgm_wav']
|
234 |
+
del item['melody_is_wav']
|
235 |
+
|
236 |
+
torch.cuda.empty_cache()
|
237 |
+
src_jsonl_name = os.path.split(input_jsonl)[-1]
|
238 |
+
with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
|
239 |
+
for item in new_items:
|
240 |
+
fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
|
generate_lowmem.sh
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export USER=root
|
2 |
+
export PYTHONDONTWRITEBYTECODE=1
|
3 |
+
export TRANSFORMERS_CACHE="$(pwd)/third_party/hub"
|
4 |
+
export NCCL_HOME=/usr/local/tccl
|
5 |
+
export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH
|
6 |
+
|
7 |
+
CKPT_PATH=$1
|
8 |
+
JSONL=$2
|
9 |
+
SAVE_DIR=$3
|
10 |
+
python3 generate_lowmem.py $CKPT_PATH $JSONL $SAVE_DIR
|
requirements.txt
CHANGED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
alias-free-torch>=0.0.6
|
2 |
+
descript-audio-codec>=1.0.0
|
3 |
+
diffusers==0.27.2
|
4 |
+
einops>=0.8.1
|
5 |
+
einops-exts==0.0.4
|
6 |
+
flashy>=0.0.2
|
7 |
+
huggingface-hub==0.25.2
|
8 |
+
julius>=0.2.7
|
9 |
+
k-diffusion==0.1.1
|
10 |
+
kaldiio>=2.18.1
|
11 |
+
lameenc>=1.8.1
|
12 |
+
librosa>=0.11.0
|
13 |
+
lightning>=2.5.2
|
14 |
+
ninja>=1.11.1.4
|
15 |
+
nnAudio>=0.3.3
|
16 |
+
openunmix>=1.3.0
|
17 |
+
peft==0.10.0
|
18 |
+
torch==2.6.0
|
19 |
+
torchaudio==2.6.0
|
20 |
+
torchvision==0.21.0
|
21 |
+
transformers==4.37.2
|
22 |
+
vector-quantize-pytorch>=1.22.17
|
23 |
+
wheel>=0.45.1
|
24 |
+
x-transformers>=2.3.25
|
requirements_nodeps.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
fairseq==0.12.2
|
2 |
+
antlr4-python3-runtime==4.8
|
3 |
+
bitarray==3.4.3
|
4 |
+
cffi==1.17.1
|
5 |
+
colorama==0.4.6
|
6 |
+
cython==3.1.2
|
7 |
+
hydra-core==1.0.7
|
8 |
+
lxml==5.4.0
|
9 |
+
omegaconf==2.2.0
|
10 |
+
portalocker==3.2.0
|
11 |
+
pycparser==2.22
|
12 |
+
sacrebleu==2.5.1
|
13 |
+
tabulate==0.9.0
|
sample/lyrics.jsonl
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
{"idx": "sample_01_autoprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "auto_prompt_audio_type": "Auto"}
|
2 |
{"idx": "sample_01_noprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
|
3 |
{"idx": "sample_01_textprompt", "descriptions": "female, dark, pop, sad, piano and drums, the bpm is 125.", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
|
4 |
-
{"idx": "sample_01_audioprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "prompt_audio_path": "
|
|
|
1 |
{"idx": "sample_01_autoprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "auto_prompt_audio_type": "Auto"}
|
2 |
{"idx": "sample_01_noprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
|
3 |
{"idx": "sample_01_textprompt", "descriptions": "female, dark, pop, sad, piano and drums, the bpm is 125.", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
|
4 |
+
{"idx": "sample_01_audioprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "prompt_audio_path": "input/sample_prompt_audio.wav"}
|
tools/gradio/app.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import gradio as gr
|
3 |
+
import json
|
4 |
+
from datetime import datetime
|
5 |
+
import yaml
|
6 |
+
import time
|
7 |
+
import re
|
8 |
+
import os.path as op
|
9 |
+
from levo_inference_lowmem import LeVoInference
|
10 |
+
|
11 |
+
EXAMPLE_LYRICS = """
|
12 |
+
[intro-short]
|
13 |
+
|
14 |
+
[verse]
|
15 |
+
夜晚的街灯闪烁
|
16 |
+
我漫步在熟悉的角落
|
17 |
+
回忆像潮水般涌来
|
18 |
+
你的笑容如此清晰
|
19 |
+
在心头无法抹去
|
20 |
+
那些曾经的甜蜜
|
21 |
+
如今只剩我独自回忆
|
22 |
+
|
23 |
+
[verse]
|
24 |
+
手机屏幕亮起
|
25 |
+
是你发来的消息
|
26 |
+
简单的几个字
|
27 |
+
却让我泪流满面
|
28 |
+
曾经的拥抱温暖
|
29 |
+
如今却变得遥远
|
30 |
+
我多想回到从前
|
31 |
+
重新拥有你的陪伴
|
32 |
+
|
33 |
+
[chorus]
|
34 |
+
回忆的温度还在
|
35 |
+
你却已不在
|
36 |
+
我的心被爱填满
|
37 |
+
却又被思念刺痛
|
38 |
+
音乐的节奏奏响
|
39 |
+
我的心却在流浪
|
40 |
+
没有你的日子
|
41 |
+
我该如何继续向前
|
42 |
+
|
43 |
+
[outro-short]
|
44 |
+
""".strip()
|
45 |
+
|
46 |
+
APP_DIR = op.dirname(op.dirname(op.dirname(op.abspath(__file__))))
|
47 |
+
MODEL = LeVoInference(sys.argv[1])
|
48 |
+
with open(op.join(APP_DIR, 'conf/vocab.yaml'), 'r', encoding='utf-8') as file:
|
49 |
+
STRUCTS = yaml.safe_load(file)
|
50 |
+
|
51 |
+
|
52 |
+
def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_coef=None, temperature=None, top_k=None, progress=gr.Progress(track_tqdm=True)):
|
53 |
+
global MODEL
|
54 |
+
global STRUCTS
|
55 |
+
params = {'cfg_coef':cfg_coef, 'temperature':temperature, 'top_k':top_k}
|
56 |
+
params = {k:v for k,v in params.items() if v is not None}
|
57 |
+
vocal_structs = ['[verse]', '[chorus]', '[bridge]']
|
58 |
+
sample_rate = MODEL.cfg.sample_rate
|
59 |
+
|
60 |
+
# format lyric
|
61 |
+
lyric = lyric.replace("[intro]", "[intro-short]").replace("[inst]", "[inst-short]").replace("[outro]", "[outro-short]")
|
62 |
+
paragraphs = [p.strip() for p in lyric.strip().split('\n\n') if p.strip()]
|
63 |
+
if len(paragraphs) < 1:
|
64 |
+
return None, json.dumps("Lyrics can not be left blank")
|
65 |
+
paragraphs_norm = []
|
66 |
+
vocal_flag = False
|
67 |
+
for para in paragraphs:
|
68 |
+
lines = para.splitlines()
|
69 |
+
struct_tag = lines[0].strip().lower()
|
70 |
+
if struct_tag not in STRUCTS:
|
71 |
+
return None, json.dumps(f"Segments should start with a structure tag in {STRUCTS}")
|
72 |
+
if struct_tag in vocal_structs:
|
73 |
+
vocal_flag = True
|
74 |
+
if len(lines) < 2 or not [line.strip() for line in lines[1:] if line.strip()]:
|
75 |
+
return None, json.dumps("The following segments require lyrics: [verse], [chorus], [bridge]")
|
76 |
+
else:
|
77 |
+
new_para_list = []
|
78 |
+
for line in lines[1:]:
|
79 |
+
new_para_list.append(re.sub(r"[^\w\s\[\]\-\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\uac00-\ud7af\u00c0-\u017f]", "", line))
|
80 |
+
new_para_str = f"{struct_tag} {'.'.join(new_para_list)}"
|
81 |
+
else:
|
82 |
+
if len(lines) > 1:
|
83 |
+
return None, json.dumps("The following segments should not contain lyrics: [intro], [intro-short], [intro-medium], [inst], [inst-short], [inst-medium], [outro], [outro-short], [outro-medium]")
|
84 |
+
else:
|
85 |
+
new_para_str = struct_tag
|
86 |
+
paragraphs_norm.append(new_para_str)
|
87 |
+
if not vocal_flag:
|
88 |
+
return None, json.dumps(f"The lyric must contain at least one of the following structures: {vocal_structs}")
|
89 |
+
lyric_norm = " ; ".join(paragraphs_norm)
|
90 |
+
|
91 |
+
# format prompt
|
92 |
+
if prompt_audio is not None:
|
93 |
+
genre = None
|
94 |
+
description = None
|
95 |
+
elif description is not None and description != "":
|
96 |
+
genre = None
|
97 |
+
|
98 |
+
progress(0.0, "Start Generation")
|
99 |
+
start = time.time()
|
100 |
+
|
101 |
+
audio_data = MODEL(lyric_norm, description, prompt_audio, genre, op.join(APP_DIR, "ckpt/prompt.pt"), params).cpu().permute(1, 0).float().numpy()
|
102 |
+
|
103 |
+
end = time.time()
|
104 |
+
|
105 |
+
# 创建输入配置的JSON
|
106 |
+
input_config = {
|
107 |
+
"lyric": lyric_norm,
|
108 |
+
"genre": genre,
|
109 |
+
"prompt_audio": prompt_audio,
|
110 |
+
"description": description,
|
111 |
+
"params": params,
|
112 |
+
"inference_duration": end - start,
|
113 |
+
"timestamp": datetime.now().isoformat(),
|
114 |
+
}
|
115 |
+
|
116 |
+
return (sample_rate, audio_data), json.dumps(input_config, indent=2)
|
117 |
+
|
118 |
+
|
119 |
+
# 创建Gradio界面
|
120 |
+
with gr.Blocks(title="SongGeneration Demo Space") as demo:
|
121 |
+
gr.Markdown("# 🎵 SongGeneration Demo Space")
|
122 |
+
gr.Markdown("Demo interface for the song generation model. Provide a lyrics, and optionally an audio or text prompt, to generate a custom song.")
|
123 |
+
|
124 |
+
with gr.Row():
|
125 |
+
with gr.Column():
|
126 |
+
lyric = gr.Textbox(
|
127 |
+
label="Lyrics",
|
128 |
+
lines=5,
|
129 |
+
max_lines=15,
|
130 |
+
value=EXAMPLE_LYRICS,
|
131 |
+
info="Each paragraph represents a segment starting with a structure tag and ending with a blank line, each line is a sentence without punctuation, segments [intro], [inst], [outro] should not contain lyrics, while [verse], [chorus], and [bridge] require lyrics.",
|
132 |
+
placeholder="""Lyric Format
|
133 |
+
'''
|
134 |
+
[structure tag]
|
135 |
+
lyrics
|
136 |
+
|
137 |
+
[structure tag]
|
138 |
+
lyrics
|
139 |
+
'''
|
140 |
+
1. One paragraph represents one segments, starting with a structure tag and ending with a blank line
|
141 |
+
2. One line represents one sentence, punctuation is not recommended inside the sentence
|
142 |
+
3. The following segments should not contain lyrics: [intro-short], [intro-medium], [inst-short], [inst-medium], [outro-short], [outro-medium]
|
143 |
+
4. The following segments require lyrics: [verse], [chorus], [bridge]
|
144 |
+
"""
|
145 |
+
)
|
146 |
+
|
147 |
+
with gr.Tabs(elem_id="extra-tabs"):
|
148 |
+
with gr.Tab("Genre Select"):
|
149 |
+
genre = gr.Radio(
|
150 |
+
choices=["Pop", "R&B", "Dance", "Jazz", "Folk", "Rock", "Chinese Style", "Chinese Tradition", "Metal", "Reggae", "Chinese Opera", "Auto"],
|
151 |
+
label="Genre Select(Optional)",
|
152 |
+
value="Pop",
|
153 |
+
interactive=True,
|
154 |
+
elem_id="single-select-radio"
|
155 |
+
)
|
156 |
+
with gr.Tab("Audio Prompt"):
|
157 |
+
prompt_audio = gr.Audio(
|
158 |
+
label="Prompt Audio (Optional)",
|
159 |
+
type="filepath",
|
160 |
+
elem_id="audio-prompt"
|
161 |
+
)
|
162 |
+
with gr.Tab("Text Prompt"):
|
163 |
+
gr.Markdown("For detailed usage, please refer to [here](https://github.com/tencent-ailab/SongGeneration?tab=readme-ov-file#-description-input-format)")
|
164 |
+
description = gr.Textbox(
|
165 |
+
label="Song Description (Optional)",
|
166 |
+
info="Describe the gender, timbre, genre, emotion, instrument and bpm of the song. Only English is supported currently.",
|
167 |
+
placeholder="female, dark, pop, sad, piano and drums, the bpm is 125.",
|
168 |
+
lines=1,
|
169 |
+
max_lines=2
|
170 |
+
)
|
171 |
+
|
172 |
+
with gr.Accordion("Advanced Config", open=False):
|
173 |
+
cfg_coef = gr.Slider(
|
174 |
+
label="CFG Coefficient",
|
175 |
+
minimum=0.1,
|
176 |
+
maximum=3.0,
|
177 |
+
step=0.1,
|
178 |
+
value=1.5,
|
179 |
+
interactive=True,
|
180 |
+
elem_id="cfg-coef",
|
181 |
+
)
|
182 |
+
temperature = gr.Slider(
|
183 |
+
label="Temperature",
|
184 |
+
minimum=0.1,
|
185 |
+
maximum=2.0,
|
186 |
+
step=0.1,
|
187 |
+
value=0.9,
|
188 |
+
interactive=True,
|
189 |
+
elem_id="temperature",
|
190 |
+
)
|
191 |
+
top_k = gr.Slider(
|
192 |
+
label="Top-K",
|
193 |
+
minimum=1,
|
194 |
+
maximum=100,
|
195 |
+
step=1,
|
196 |
+
value=50,
|
197 |
+
interactive=True,
|
198 |
+
elem_id="top_k",
|
199 |
+
)
|
200 |
+
generate_btn = gr.Button("Generate Song", variant="primary")
|
201 |
+
|
202 |
+
with gr.Column():
|
203 |
+
output_audio = gr.Audio(label="Generated Song", type="numpy")
|
204 |
+
output_json = gr.JSON(label="Generated Info")
|
205 |
+
|
206 |
+
# # 示例按钮
|
207 |
+
# examples = gr.Examples(
|
208 |
+
# examples=[
|
209 |
+
# ["male, bright, rock, happy, electric guitar and drums, the bpm is 150."],
|
210 |
+
# ["female, warm, jazz, romantic, synthesizer and piano, the bpm is 100."]
|
211 |
+
# ],
|
212 |
+
# inputs=[description],
|
213 |
+
# label="Text Prompt examples"
|
214 |
+
# )
|
215 |
+
|
216 |
+
# examples = gr.Examples(
|
217 |
+
# examples=[
|
218 |
+
# "[intro-medium]\n\n[verse]\n在这个疯狂的世界里\n谁不渴望一点改变\n在爱情面前\n我们都显得那么不安全\n你紧紧抱着我\n告诉我再靠近一点\n别让这璀璨的夜晚白白浪费\n我那迷茫的眼睛\n看不见未来的路\n在情感消散之前\n我们对爱的渴望永不熄灭\n你给我留下一句誓言\n想知道我们的爱是否能持续到永远\n[chorus]\n\n约定在那最后的夜晚\n不管命运如何摆布\n我们的心是否依然如初\n我会穿上红衬衫\n带着摇滚的激情\n回到我们初遇的地方\n约定在那最后的夜晚\n就算全世界都变了样\n我依然坚守诺言\n铭记这一天\n你永远是我心中的爱恋\n\n[outro-medium]\n",
|
219 |
+
# "[intro-short]\n\n[verse]\nThrough emerald canyons where fireflies dwell\nCerulean berries kiss morning's first swell\nCrystalline dew crowns each Vitamin Dawn's confection dissolves slowly on me\nAmbrosia breezes through honeycomb vines\nNature's own candy in Fibonacci lines\n[chorus] Blueberry fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n You're under its spell\n [verse] Resin of sunlight in candied retreat\nMarmalade moonbeams melt under bare feet\nNectar spirals bloom chloroplast champagne\nPhotosynthesis sings through my veins\nChlorophyll rhythms pulse warm in my blood\nThe forest's green pharmacy floods every bud[chorus] Blueberry fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n You're under its spell\n feel the buzz\n ride the wave\n Limey me\n blueberry\n your mind's enslaved\n In the haze\n lose all time\n floating free\n feeling fine\n Blueberry\n fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n cry\n You're under its spell\n\n[outro-short]\n",
|
220 |
+
# ],
|
221 |
+
# inputs=[lyric],
|
222 |
+
# label="Lyrics examples",
|
223 |
+
# )
|
224 |
+
|
225 |
+
# 生成按钮点击事件
|
226 |
+
generate_btn.click(
|
227 |
+
fn=generate_song,
|
228 |
+
inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, top_k],
|
229 |
+
outputs=[output_audio, output_json]
|
230 |
+
)
|
231 |
+
|
232 |
+
|
233 |
+
# 启动应用
|
234 |
+
if __name__ == "__main__":
|
235 |
+
demo.launch(server_name="0.0.0.0", server_port=8081)
|
236 |
+
|
tools/gradio/levo_inference.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
import json
|
7 |
+
import numpy as np
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
|
10 |
+
from codeclm.trainer.codec_song_pl import CodecLM_PL
|
11 |
+
from codeclm.models import CodecLM
|
12 |
+
|
13 |
+
from separator import Separator
|
14 |
+
|
15 |
+
|
16 |
+
class LeVoInference(torch.nn.Module):
|
17 |
+
def __init__(self, ckpt_path):
|
18 |
+
super().__init__()
|
19 |
+
|
20 |
+
torch.backends.cudnn.enabled = False
|
21 |
+
OmegaConf.register_new_resolver("eval", lambda x: eval(x))
|
22 |
+
OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
|
23 |
+
OmegaConf.register_new_resolver("get_fname", lambda: 'default')
|
24 |
+
OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
|
25 |
+
|
26 |
+
cfg_path = os.path.join(ckpt_path, 'config.yaml')
|
27 |
+
pt_path = os.path.join(ckpt_path, 'model.pt')
|
28 |
+
|
29 |
+
self.cfg = OmegaConf.load(cfg_path)
|
30 |
+
self.cfg.mode = 'inference'
|
31 |
+
self.max_duration = self.cfg.max_dur
|
32 |
+
|
33 |
+
# Define model or load pretrained model
|
34 |
+
model_light = CodecLM_PL(self.cfg, pt_path)
|
35 |
+
|
36 |
+
model_light = model_light.eval().cuda()
|
37 |
+
model_light.audiolm.cfg = self.cfg
|
38 |
+
|
39 |
+
self.model_lm = model_light.audiolm
|
40 |
+
self.model_audio_tokenizer = model_light.audio_tokenizer
|
41 |
+
self.model_seperate_tokenizer = model_light.seperate_tokenizer
|
42 |
+
|
43 |
+
self.model = CodecLM(name = "tmp",
|
44 |
+
lm = self.model_lm,
|
45 |
+
audiotokenizer = self.model_audio_tokenizer,
|
46 |
+
max_duration = self.max_duration,
|
47 |
+
seperate_tokenizer = self.model_seperate_tokenizer,
|
48 |
+
)
|
49 |
+
self.separator = Separator()
|
50 |
+
|
51 |
+
|
52 |
+
self.default_params = dict(
|
53 |
+
cfg_coef = 1.5,
|
54 |
+
temperature = 1.0,
|
55 |
+
top_k = 50,
|
56 |
+
top_p = 0.0,
|
57 |
+
record_tokens = True,
|
58 |
+
record_window = 50,
|
59 |
+
extend_stride = 5,
|
60 |
+
duration = self.max_duration,
|
61 |
+
)
|
62 |
+
|
63 |
+
self.model.set_generation_params(**self.default_params)
|
64 |
+
|
65 |
+
def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, params = dict()):
|
66 |
+
params = {**self.default_params, **params}
|
67 |
+
self.model.set_generation_params(**params)
|
68 |
+
|
69 |
+
if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
|
70 |
+
pmt_wav, vocal_wav, bgm_wav = self.separator.run(prompt_audio_path)
|
71 |
+
melody_is_wav = True
|
72 |
+
elif genre is not None and auto_prompt_path is not None:
|
73 |
+
auto_prompt = torch.load(auto_prompt_path)
|
74 |
+
merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
|
75 |
+
if genre == "Auto":
|
76 |
+
prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
|
77 |
+
else:
|
78 |
+
prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
|
79 |
+
pmt_wav = prompt_token[:,[0],:]
|
80 |
+
vocal_wav = prompt_token[:,[1],:]
|
81 |
+
bgm_wav = prompt_token[:,[2],:]
|
82 |
+
melody_is_wav = False
|
83 |
+
else:
|
84 |
+
pmt_wav = None
|
85 |
+
vocal_wav = None
|
86 |
+
bgm_wav = None
|
87 |
+
melody_is_wav = True
|
88 |
+
|
89 |
+
generate_inp = {
|
90 |
+
'lyrics': [lyric.replace(" ", " ")],
|
91 |
+
'descriptions': [description],
|
92 |
+
'melody_wavs': pmt_wav,
|
93 |
+
'vocal_wavs': vocal_wav,
|
94 |
+
'bgm_wavs': bgm_wav,
|
95 |
+
'melody_is_wav': melody_is_wav,
|
96 |
+
}
|
97 |
+
|
98 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
99 |
+
tokens = self.model.generate(**generate_inp, return_tokens=True)
|
100 |
+
|
101 |
+
if tokens.shape[-1] > 3000:
|
102 |
+
tokens = tokens[..., :3000]
|
103 |
+
|
104 |
+
with torch.no_grad():
|
105 |
+
if melody_is_wav:
|
106 |
+
wav_seperate = self.model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)
|
107 |
+
else:
|
108 |
+
wav_seperate = self.model.generate_audio(tokens)
|
109 |
+
|
110 |
+
return wav_seperate[0]
|
tools/gradio/levo_inference_lowmem.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
import json
|
7 |
+
import numpy as np
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
|
10 |
+
from codeclm.trainer.codec_song_pl import CodecLM_PL
|
11 |
+
from codeclm.models import CodecLM
|
12 |
+
from codeclm.models import builders
|
13 |
+
|
14 |
+
from separator import Separator
|
15 |
+
|
16 |
+
|
17 |
+
class LeVoInference(torch.nn.Module):
|
18 |
+
def __init__(self, ckpt_path):
|
19 |
+
super().__init__()
|
20 |
+
|
21 |
+
torch.backends.cudnn.enabled = False
|
22 |
+
OmegaConf.register_new_resolver("eval", lambda x: eval(x))
|
23 |
+
OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
|
24 |
+
OmegaConf.register_new_resolver("get_fname", lambda: 'default')
|
25 |
+
OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
|
26 |
+
|
27 |
+
cfg_path = os.path.join(ckpt_path, 'config.yaml')
|
28 |
+
self.pt_path = os.path.join(ckpt_path, 'model.pt')
|
29 |
+
|
30 |
+
self.cfg = OmegaConf.load(cfg_path)
|
31 |
+
self.cfg.mode = 'inference'
|
32 |
+
self.max_duration = self.cfg.max_dur
|
33 |
+
|
34 |
+
self.default_params = dict(
|
35 |
+
top_p = 0.0,
|
36 |
+
record_tokens = True,
|
37 |
+
record_window = 50,
|
38 |
+
extend_stride = 5,
|
39 |
+
duration = self.max_duration,
|
40 |
+
)
|
41 |
+
|
42 |
+
|
43 |
+
def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, params = dict()):
|
44 |
+
if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
|
45 |
+
separator = Separator()
|
46 |
+
audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
|
47 |
+
audio_tokenizer = audio_tokenizer.eval().cuda()
|
48 |
+
seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
|
49 |
+
seperate_tokenizer = seperate_tokenizer.eval().cuda()
|
50 |
+
pmt_wav, vocal_wav, bgm_wav = separator.run(prompt_audio_path)
|
51 |
+
pmt_wav = pmt_wav.cuda()
|
52 |
+
vocal_wav = vocal_wav.cuda()
|
53 |
+
bgm_wav = bgm_wav.cuda()
|
54 |
+
pmt_wav, _ = audio_tokenizer.encode(pmt_wav)
|
55 |
+
vocal_wav, bgm_wav = seperate_tokenizer.encode(vocal_wav, bgm_wav)
|
56 |
+
melody_is_wav = False
|
57 |
+
melody_is_wav = False
|
58 |
+
del audio_tokenizer
|
59 |
+
del seperate_tokenizer
|
60 |
+
del separator
|
61 |
+
elif genre is not None and auto_prompt_path is not None:
|
62 |
+
auto_prompt = torch.load(auto_prompt_path)
|
63 |
+
merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
|
64 |
+
if genre == "Auto":
|
65 |
+
prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
|
66 |
+
else:
|
67 |
+
prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
|
68 |
+
pmt_wav = prompt_token[:,[0],:]
|
69 |
+
vocal_wav = prompt_token[:,[1],:]
|
70 |
+
bgm_wav = prompt_token[:,[2],:]
|
71 |
+
melody_is_wav = False
|
72 |
+
else:
|
73 |
+
pmt_wav = None
|
74 |
+
vocal_wav = None
|
75 |
+
bgm_wav = None
|
76 |
+
melody_is_wav = True
|
77 |
+
|
78 |
+
model_light = CodecLM_PL(self.cfg, self.pt_path)
|
79 |
+
model_light = model_light.eval()
|
80 |
+
model_light.audiolm.cfg = self.cfg
|
81 |
+
model = CodecLM(name = "tmp",
|
82 |
+
lm = model_light.audiolm,
|
83 |
+
audiotokenizer = None,
|
84 |
+
max_duration = self.max_duration,
|
85 |
+
seperate_tokenizer = None,
|
86 |
+
)
|
87 |
+
del model_light
|
88 |
+
model.lm = model.lm.cuda().to(torch.float16)
|
89 |
+
params = {**self.default_params, **params}
|
90 |
+
model.set_generation_params(**params)
|
91 |
+
|
92 |
+
generate_inp = {
|
93 |
+
'lyrics': [lyric.replace(" ", " ")],
|
94 |
+
'descriptions': [description],
|
95 |
+
'melody_wavs': pmt_wav,
|
96 |
+
'vocal_wavs': vocal_wav,
|
97 |
+
'bgm_wavs': bgm_wav,
|
98 |
+
'melody_is_wav': melody_is_wav,
|
99 |
+
}
|
100 |
+
|
101 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
102 |
+
tokens = model.generate(**generate_inp, return_tokens=True)
|
103 |
+
|
104 |
+
del model
|
105 |
+
torch.cuda.empty_cache()
|
106 |
+
|
107 |
+
seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
|
108 |
+
seperate_tokenizer = seperate_tokenizer.eval().cuda()
|
109 |
+
model = CodecLM(name = "tmp",
|
110 |
+
lm = None,
|
111 |
+
audiotokenizer = None,
|
112 |
+
max_duration = self.max_duration,
|
113 |
+
seperate_tokenizer = seperate_tokenizer,
|
114 |
+
)
|
115 |
+
|
116 |
+
if tokens.shape[-1] > 3000:
|
117 |
+
tokens = tokens[..., :3000]
|
118 |
+
|
119 |
+
with torch.no_grad():
|
120 |
+
if melody_is_wav:
|
121 |
+
wav_seperate = model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)
|
122 |
+
else:
|
123 |
+
wav_seperate = model.generate_audio(tokens)
|
124 |
+
|
125 |
+
del seperate_tokenizer
|
126 |
+
del model
|
127 |
+
torch.cuda.empty_cache()
|
128 |
+
|
129 |
+
return wav_seperate[0]
|
tools/gradio/run.sh
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export USER=root
|
2 |
+
export PYTHONDONTWRITEBYTECODE=1
|
3 |
+
export TRANSFORMERS_CACHE="$(pwd)/third_party/hub"
|
4 |
+
export NCCL_HOME=/usr/local/tccl
|
5 |
+
export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH
|
6 |
+
|
7 |
+
|
8 |
+
CKPT_PATH=$1
|
9 |
+
python3 tools/gradio/app.py $CKPT_PATH
|
tools/gradio/separator.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchaudio
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from third_party.demucs.models.pretrained import get_model_from_yaml
|
5 |
+
|
6 |
+
|
7 |
+
class Separator(torch.nn.Module):
|
8 |
+
def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
|
9 |
+
super().__init__()
|
10 |
+
if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
|
11 |
+
self.device = torch.device(f"cuda:{gpu_id}")
|
12 |
+
else:
|
13 |
+
self.device = torch.device("cpu")
|
14 |
+
self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
|
15 |
+
|
16 |
+
def init_demucs_model(self, model_path, config_path):
|
17 |
+
model = get_model_from_yaml(config_path, model_path)
|
18 |
+
model.to(self.device)
|
19 |
+
model.eval()
|
20 |
+
return model
|
21 |
+
|
22 |
+
def load_audio(self, f):
|
23 |
+
a, fs = torchaudio.load(f)
|
24 |
+
if (fs != 48000):
|
25 |
+
a = torchaudio.functional.resample(a, fs, 48000)
|
26 |
+
if a.shape[-1] >= 48000*10:
|
27 |
+
a = a[..., :48000*10]
|
28 |
+
else:
|
29 |
+
a = torch.cat([a, a], -1)
|
30 |
+
return a[:, 0:48000*10]
|
31 |
+
|
32 |
+
def run(self, audio_path, output_dir='tmp', ext=".flac"):
|
33 |
+
os.makedirs(output_dir, exist_ok=True)
|
34 |
+
name, _ = os.path.splitext(os.path.split(audio_path)[-1])
|
35 |
+
output_paths = []
|
36 |
+
|
37 |
+
for stem in self.demucs_model.sources:
|
38 |
+
output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
|
39 |
+
if os.path.exists(output_path):
|
40 |
+
output_paths.append(output_path)
|
41 |
+
if len(output_paths) == 1: # 4
|
42 |
+
vocal_path = output_paths[0]
|
43 |
+
else:
|
44 |
+
drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
|
45 |
+
for path in [drums_path, bass_path, other_path]:
|
46 |
+
os.remove(path)
|
47 |
+
full_audio = self.load_audio(audio_path)
|
48 |
+
vocal_audio = self.load_audio(vocal_path)
|
49 |
+
bgm_audio = full_audio - vocal_audio
|
50 |
+
return full_audio, vocal_audio, bgm_audio
|