waytan22 committed on
Commit
93f7efb
·
1 Parent(s): 1be5a09

fix some typos

Files changed (34)
  1. LICENSE +211 -211
  2. README.md +0 -2
  3. app.py +3 -3
  4. codeclm/models/codeclm.py +40 -53
  5. codeclm/tokenizer/Flow1dVAE/generate_septoken.py +3 -2
  6. codeclm/tokenizer/Flow1dVAE/libs/rvq/core_vq.py +366 -366
  7. codeclm/tokenizer/Flow1dVAE/model_1rvq.py +710 -710
  8. codeclm/tokenizer/Flow1dVAE/model_2rvq.py +774 -774
  9. codeclm/tokenizer/Flow1dVAE/model_4rvq.py +774 -774
  10. codeclm/tokenizer/Flow1dVAE/model_septoken.py +670 -670
  11. codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_additionalemb.py +0 -0
  12. codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_flow.py +0 -0
  13. codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/pinyin/symbols.py +71 -71
  14. codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k.py +47 -47
  15. codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k_vocal.py +47 -47
  16. codeclm/tokenizer/Flow1dVAE/tools/infer_hifigan48k_speech.py +56 -56
  17. codeclm/tokenizer/Flow1dVAE/tools/infer_hifigan48k_vocal.py +57 -57
  18. codeclm/tokenizer/Flow1dVAE/tools/infer_vaehifigan48k.py +59 -59
  19. codeclm/tokenizer/Flow1dVAE/tools/infer_vaehifigan48k_soundmusic.py +61 -61
  20. codeclm/tokenizer/Flow1dVAE/tools/infer_vaehifigan48k_speech.py +58 -58
  21. codeclm/tokenizer/Flow1dVAE/tools/infer_vaehifigan48k_vocal.py +59 -59
  22. codeclm/tokenizer/Flow1dVAE/tools/mix.py +50 -50
  23. codeclm/tokenizer/Flow1dVAE/tools/torch_tools.py +142 -142
  24. codeclm/tokenizer/audio_tokenizer.py +2 -2
  25. generate_lowmem.py +240 -0
  26. generate_lowmem.sh +10 -0
  27. requirements.txt +24 -0
  28. requirements_nodeps.txt +13 -0
  29. sample/lyrics.jsonl +1 -1
  30. tools/gradio/app.py +236 -0
  31. tools/gradio/levo_inference.py +110 -0
  32. tools/gradio/levo_inference_lowmem.py +129 -0
  33. tools/gradio/run.sh +9 -0
  34. tools/gradio/separator.py +50 -0
LICENSE CHANGED
@@ -1,211 +1,211 @@
1
- Tencent is pleased to support the open source community by making SongGeneration available.
2
-
3
- Copyright (C) 2025 Tencent. All rights reserved.
4
-
5
- SongGeneration is licensed under the License Terms of SongGeneration except for the third-party components listed below, which is licensed under different terms. SongGeneration does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
6
-
7
-
8
- License Terms of SongGeneration:
9
- --------------------------------------------------------------------
10
-
11
- Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
12
-
13
- - You agree to use the SongGeneration only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances.
14
-
15
- - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
16
-
17
- For avoidance of doubts, "Software" means the SongGeneration inference-enabling code and the weights made available under this license excluding any pre-trained data and other AI components.
18
-
19
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
20
-
21
-
22
- Other dependencies and licenses:
23
-
24
-
25
- Open Source Software Licensed under the MIT License and Other Licenses of the Third-Party Components therein:
26
- --------------------------------------------------------------------
27
- 1. stable_audio_tools
28
- Copyright (c) 2023 Stability AI
29
-
30
-
31
- Terms of the MIT:
32
- --------------------------------------------------------------------
33
- Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
34
-
35
- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
36
-
37
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
38
-
39
- For the license of other third party components, please refer to the following URL:
40
- https://github.com/Stability-AI/stable-audio-tools/tree/main/LICENSES
41
-
42
-
43
- Open Source Software Licensed under the MIT License:
44
- --------------------------------------------------------------------
45
- 1. demucs
46
- Copyright (c) Meta Platforms, Inc. and affiliates.
47
-
48
-
49
- A copy of the MIT is included in this file.
50
-
51
-
52
-
53
- Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
54
- --------------------------------------------------------------------
55
- 1. torch
56
- From PyTorch:
57
-
58
- Copyright (c) 2016- Facebook, Inc (Adam Paszke)
59
- Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
60
- Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
61
- Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
62
- Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
63
- Copyright (c) 2011-2013 NYU (Clement Farabet)
64
- Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
65
- Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
66
- Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
67
-
68
- From Caffe2:
69
-
70
- Copyright (c) 2016-present, Facebook Inc. All rights reserved.
71
-
72
- All contributions by Facebook:
73
- Copyright (c) 2016 Facebook Inc.
74
-
75
- All contributions by Google:
76
- Copyright (c) 2015 Google Inc.
77
- All rights reserved.
78
-
79
- All contributions by Yangqing Jia:
80
- Copyright (c) 2015 Yangqing Jia
81
- All rights reserved.
82
-
83
- All contributions by Kakao Brain:
84
- Copyright 2019-2020 Kakao Brain
85
-
86
- All contributions by Cruise LLC:
87
- Copyright (c) 2022 Cruise LLC.
88
- All rights reserved.
89
-
90
- All contributions from Caffe:
91
- Copyright(c) 2013, 2014, 2015, the respective contributors
92
- All rights reserved.
93
-
94
- All other contributions:
95
- Copyright(c) 2015, 2016 the respective contributors
96
- All rights reserved.
97
-
98
- Caffe2 uses a copyright model similar to Caffe: each contributor holds
99
- copyright over their contributions to Caffe2. The project versioning records
100
- all such contribution and copyright details. If a contributor wants to further
101
- mark their specific copyright on a particular contribution, they should
102
- indicate their copyright solely in the commit message of the change when it is
103
- committed.
104
-
105
- All rights reserved.
106
-
107
-
108
- Terms of the BSD 3-Clause:
109
- --------------------------------------------------------------------
110
- Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
111
-
112
- 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
113
-
114
- 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
115
-
116
- 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
117
-
118
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
119
-
120
- For the license of other third party components, please refer to the following URL:
121
- https://github.com/pytorch/pytorch/blob/v2.0.1/NOTICE
122
-
123
-
124
- Open Source Software Licensed under the BSD 2-Clause License and Other Licenses of the Third-Party Components therein:
125
- --------------------------------------------------------------------
126
- 1. torchaudio
127
- Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
128
- All rights reserved.
129
-
130
-
131
- Terms of the BSD 2-Clause:
132
- --------------------------------------------------------------------
133
- Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
134
-
135
- * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
136
-
137
- * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
138
-
139
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
140
-
141
- For the license of other third party components, please refer to the following URL:
142
- https://github.com/pytorch/audio/blob/v2.0.2/LICENSE
143
-
144
-
145
- Open Source Software License under the Apache License Version 2.0:
146
- --------------------------------------------------------------------
147
- 1. huggingface-hub
148
- Copyright (c) huggingface-hub original author and authors
149
-
150
- 2. transformers
151
- Copyright 2018- The Hugging Face team. All rights reserved.
152
-
153
-
154
- Terms of the Apache License Version 2.0:
155
- --------------------------------------------------------------------
156
- Apache License
157
-
158
- Version 2.0, January 2004
159
-
160
- http://www.apache.org/licenses/
161
-
162
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
163
- 1. Definitions.
164
-
165
- "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
166
-
167
- "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
168
-
169
- "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
170
-
171
- "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
172
-
173
- "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
174
-
175
- "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
176
-
177
- "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
178
-
179
- "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
180
-
181
- "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
182
-
183
- "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
184
-
185
- 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
186
-
187
- 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
188
-
189
- 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
190
-
191
- You must give any other recipients of the Work or Derivative Works a copy of this License; and
192
-
193
- You must cause any modified files to carry prominent notices stating that You changed the files; and
194
-
195
- You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
196
-
197
- If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
198
-
199
- You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
200
-
201
- 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
202
-
203
- 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
204
-
205
- 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
206
-
207
- 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
208
-
209
- 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
210
-
211
- END OF TERMS AND CONDITIONS
 
1
+ Tencent is pleased to support the open source community by making SongGeneration available.
2
+
3
+ Copyright (C) 2025 Tencent. All rights reserved.
4
+
5
+ SongGeneration is licensed under the License Terms of SongGeneration except for the third-party components listed below, which is licensed under different terms. SongGeneration does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
6
+
7
+
8
+ License Terms of SongGeneration:
9
+ --------------------------------------------------------------------
10
+
11
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
12
+
13
+ - You agree to use the SongGeneration only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances.
14
+
15
+ - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
16
+
17
+ For avoidance of doubts, "Software" means the SongGeneration inference-enabling code and the weights made available under this license excluding any pre-trained data and other AI components.
18
+
19
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
20
+
21
+
22
+ Other dependencies and licenses:
23
+
24
+
25
+ Open Source Software Licensed under the MIT License and Other Licenses of the Third-Party Components therein:
26
+ --------------------------------------------------------------------
27
+ 1. stable_audio_tools
28
+ Copyright (c) 2023 Stability AI
29
+
30
+
31
+ Terms of the MIT:
32
+ --------------------------------------------------------------------
33
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
34
+
35
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
36
+
37
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
38
+
39
+ For the license of other third party components, please refer to the following URL:
40
+ https://github.com/Stability-AI/stable-audio-tools/tree/main/LICENSES
41
+
42
+
43
+ Open Source Software Licensed under the MIT License:
44
+ --------------------------------------------------------------------
45
+ 1. demucs
46
+ Copyright (c) Meta Platforms, Inc. and affiliates.
47
+
48
+
49
+ A copy of the MIT is included in this file.
50
+
51
+
52
+
53
+ Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
54
+ --------------------------------------------------------------------
55
+ 1. torch
56
+ From PyTorch:
57
+
58
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
59
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
60
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
61
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
62
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
63
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
64
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
65
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
66
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
67
+
68
+ From Caffe2:
69
+
70
+ Copyright (c) 2016-present, Facebook Inc. All rights reserved.
71
+
72
+ All contributions by Facebook:
73
+ Copyright (c) 2016 Facebook Inc.
74
+
75
+ All contributions by Google:
76
+ Copyright (c) 2015 Google Inc.
77
+ All rights reserved.
78
+
79
+ All contributions by Yangqing Jia:
80
+ Copyright (c) 2015 Yangqing Jia
81
+ All rights reserved.
82
+
83
+ All contributions by Kakao Brain:
84
+ Copyright 2019-2020 Kakao Brain
85
+
86
+ All contributions by Cruise LLC:
87
+ Copyright (c) 2022 Cruise LLC.
88
+ All rights reserved.
89
+
90
+ All contributions from Caffe:
91
+ Copyright(c) 2013, 2014, 2015, the respective contributors
92
+ All rights reserved.
93
+
94
+ All other contributions:
95
+ Copyright(c) 2015, 2016 the respective contributors
96
+ All rights reserved.
97
+
98
+ Caffe2 uses a copyright model similar to Caffe: each contributor holds
99
+ copyright over their contributions to Caffe2. The project versioning records
100
+ all such contribution and copyright details. If a contributor wants to further
101
+ mark their specific copyright on a particular contribution, they should
102
+ indicate their copyright solely in the commit message of the change when it is
103
+ committed.
104
+
105
+ All rights reserved.
106
+
107
+
108
+ Terms of the BSD 3-Clause:
109
+ --------------------------------------------------------------------
110
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
111
+
112
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
113
+
114
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
115
+
116
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
117
+
118
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
119
+
120
+ For the license of other third party components, please refer to the following URL:
121
+ https://github.com/pytorch/pytorch/blob/v2.0.1/NOTICE
122
+
123
+
124
+ Open Source Software Licensed under the BSD 2-Clause License and Other Licenses of the Third-Party Components therein:
125
+ --------------------------------------------------------------------
126
+ 1. torchaudio
127
+ Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
128
+ All rights reserved.
129
+
130
+
131
+ Terms of the BSD 2-Clause:
132
+ --------------------------------------------------------------------
133
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
134
+
135
+ * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
136
+
137
+ * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
138
+
139
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
140
+
141
+ For the license of other third party components, please refer to the following URL:
142
+ https://github.com/pytorch/audio/blob/v2.0.2/LICENSE
143
+
144
+
145
+ Open Source Software License under the Apache License Version 2.0:
146
+ --------------------------------------------------------------------
147
+ 1. huggingface-hub
148
+ Copyright (c) huggingface-hub original author and authors
149
+
150
+ 2. transformers
151
+ Copyright 2018- The Hugging Face team. All rights reserved.
152
+
153
+
154
+ Terms of the Apache License Version 2.0:
155
+ --------------------------------------------------------------------
156
+ Apache License
157
+
158
+ Version 2.0, January 2004
159
+
160
+ http://www.apache.org/licenses/
161
+
162
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
163
+ 1. Definitions.
164
+
165
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
166
+
167
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
168
+
169
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
170
+
171
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
172
+
173
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
174
+
175
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
176
+
177
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
178
+
179
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
180
+
181
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
182
+
183
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
184
+
185
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
186
+
187
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
188
+
189
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
190
+
191
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
192
+
193
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
194
+
195
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
196
+
197
+ If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
198
+
199
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
200
+
201
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
202
+
203
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
204
+
205
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
206
+
207
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
208
+
209
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
210
+
211
+ END OF TERMS AND CONDITIONS
README.md CHANGED
@@ -7,11 +7,9 @@ sdk: docker
7
  app_port: 7860
8
  ---
9
 
10
-
11
  <p align="center">
12
  <a href="https://levo-demo.github.io/">Demo</a> &nbsp;|&nbsp; <a href="https://arxiv.org/abs/2506.07520">Paper</a> &nbsp;|&nbsp; <a href="https://github.com/tencent-ailab/songgeneration">Code</a>
13
  </p>
14
-
15
  This repository is the official weight repository for LeVo: High-Quality Song Generation with Multi-Preference Alignment. In this repository, we provide the SongGeneration model, inference scripts, and the checkpoint that has been trained on the Million Song Dataset.
16
 
17
  ## Overview
 
7
  app_port: 7860
8
  ---
9
 
 
10
  <p align="center">
11
  <a href="https://levo-demo.github.io/">Demo</a> &nbsp;|&nbsp; <a href="https://arxiv.org/abs/2506.07520">Paper</a> &nbsp;|&nbsp; <a href="https://github.com/tencent-ailab/songgeneration">Code</a>
12
  </p>
 
13
  This repository is the official weight repository for LeVo: High-Quality Song Generation with Multi-Preference Alignment. In this repository, we provide the SongGeneration model, inference scripts, and the checkpoint that has been trained on the Million Song Dataset.
14
 
15
  ## Overview
app.py CHANGED
@@ -124,9 +124,9 @@ def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_co
124
 
125
 
126
  # Create the Gradio interface
127
- with gr.Blocks(title="SongGeration Demo Space") as demo:
128
- gr.Markdown("# 🎵 SongGeration Demo Space")
129
- gr.Markdown("Demo interface for the song generation model. Provide a lyrics, and optionally an audio or text prompt, to generate a custom song.")
130
 
131
  with gr.Row():
132
  with gr.Column():
 
124
 
125
 
126
  # Create the Gradio interface
127
+ with gr.Blocks(title="SongGeneration Demo Space") as demo:
128
+ gr.Markdown("# 🎵 SongGeneration Demo Space")
129
+ gr.Markdown("Demo interface for the song generation model. Provide a lyrics, and optionally an audio or text prompt, to generate a custom song. The code is in [GIT](https://github.com/tencent-ailab/SongGeneration)")
130
 
131
  with gr.Row():
132
  with gr.Column():
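For reference, a minimal, self-contained sketch of the demo shell that the corrected hunk above describes. It only reproduces the renamed header widgets from these lines; the real app.py wires up generate_song and many more inputs, so the widget inside the Column is a placeholder.

```python
import gradio as gr

# Sketch of the renamed demo header (see the "+" lines above); the Textbox is a
# stand-in for the real input widgets in app.py.
with gr.Blocks(title="SongGeneration Demo Space") as demo:
    gr.Markdown("# 🎵 SongGeneration Demo Space")
    gr.Markdown(
        "Demo interface for the song generation model. Provide lyrics, and "
        "optionally an audio or text prompt, to generate a custom song."
    )
    with gr.Row():
        with gr.Column():
            gr.Textbox(label="Lyrics", lines=8)  # placeholder input

if __name__ == "__main__":
    demo.launch()
```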
codeclm/models/codeclm.py CHANGED
@@ -36,6 +36,10 @@ class CodecLM:
36
  max_duration: tp.Optional[float] = None, seperate_tokenizer: AudioTokenizer = None):
37
  self.name = name
38
  self.audiotokenizer = audiotokenizer
 
 
 
 
39
  self.lm = lm
40
  self.seperate_tokenizer = seperate_tokenizer
41
  # import pdb; pdb.set_trace()
@@ -47,7 +51,7 @@ class CodecLM:
47
  assert max_duration is not None
48
 
49
  self.max_duration: float = max_duration
50
- self.device = next(iter(lm.parameters())).device
51
  self.generation_params: dict = {}
52
  # self.set_generation_params(duration=15) # 15 seconds by default
53
  self.set_generation_params(duration=15, extend_stride=self.max_duration // 2)
@@ -57,23 +61,6 @@ class CodecLM:
57
  else:
58
  self.autocast = TorchAutocast(enabled=False)
59
 
60
-
61
-
62
- @property
63
- def frame_rate(self) -> float:
64
- """Roughly the number of AR steps per seconds."""
65
- return self.audiotokenizer.frame_rate
66
-
67
- @property
68
- def sample_rate(self) -> int:
69
- """Sample rate of the generated audio."""
70
- return self.audiotokenizer.sample_rate
71
-
72
- @property
73
- def audio_channels(self) -> int:
74
- """Audio channels of the generated audio."""
75
- return self.audiotokenizer.channels
76
-
77
  def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
78
  top_p: float = 0.0, temperature: float = 1.0,
79
  duration: float = 30.0, cfg_coef: float = 3.0,
@@ -185,7 +172,7 @@ class CodecLM:
185
  assert len(lyrics) == 1
186
  texts = [lyric for lyric in lyrics]
187
  audio_qt_embs = []
188
- target_melody_token_len = self.lm.cfg.prompt_len * self.audiotokenizer.frame_rate
189
  # import pdb; pdb.set_trace()
190
  if melody_wavs is None:
191
  melody_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
@@ -207,39 +194,39 @@ class CodecLM:
207
  melody_tokens = melody_tokens[...,:target_melody_token_len]
208
  elif melody_tokens.shape[-1] < target_melody_token_len:
209
  melody_tokens = torch.cat([melody_tokens, torch.full((1,1,target_melody_token_len - melody_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
210
- if self.seperate_tokenizer is not None:
211
- if bgm_wavs is None:
212
- assert vocal_wavs is None, "vocal_wavs is not None when bgm_wavs is None"
213
- bgm_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
214
- vocal_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
 
 
 
 
 
 
 
 
 
 
215
  else:
216
- assert vocal_wavs is not None, "vocal_wavs is None when bgm_wavs is not None"
217
- if type(vocal_wavs) == list:
218
- vocal_wavs = torch.stack(vocal_wavs, dim=0)
219
- if type(bgm_wavs) == list:
220
- bgm_wavs = torch.stack(bgm_wavs, dim=0)
221
- vocal_wavs = vocal_wavs.to(self.device)
222
- bgm_wavs = bgm_wavs.to(self.device)
223
- if melody_is_wav:
224
- vocal_tokens, bgm_tokens = self.seperate_tokenizer.encode(vocal_wavs, bgm_wavs)
225
- else:
226
- vocal_tokens = vocal_wavs
227
- bgm_tokens = bgm_wavs
228
- assert len(vocal_tokens.shape) == len(bgm_tokens.shape) == 3, \
229
- f"vocal and bgm tokens should have a shape [B, C, T]! " \
230
- f"got vocal len={vocal_tokens.shape}, and bgm len={bgm_tokens.shape}"
231
- assert vocal_tokens.shape[-1] == bgm_tokens.shape[-1], \
232
- f"vocal and bgm tokens should have the same length! " \
233
- f"got vocal len={vocal_tokens.shape[-1]}, and bgm len={bgm_tokens.shape[-1]}"
234
- if bgm_tokens.shape[-1] > target_melody_token_len:
235
- bgm_tokens = bgm_tokens[...,:target_melody_token_len]
236
- elif bgm_tokens.shape[-1] < target_melody_token_len:
237
- bgm_tokens = torch.cat([bgm_tokens, torch.full((1,1,target_melody_token_len - bgm_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
238
- if vocal_tokens.shape[-1] > target_melody_token_len:
239
- vocal_tokens = vocal_tokens[...,:target_melody_token_len]
240
- elif vocal_tokens.shape[-1] < target_melody_token_len:
241
- vocal_tokens = torch.cat([vocal_tokens, torch.full((1,1,target_melody_token_len - vocal_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
242
- melody_tokens = torch.cat([melody_tokens, vocal_tokens, bgm_tokens], dim=1)
243
  assert melody_tokens.shape[-1] == target_melody_token_len
244
  audio_qt_embs = melody_tokens.long()
245
  return texts, audio_qt_embs
@@ -284,7 +271,7 @@ class CodecLM:
284
  return gen_tokens
285
 
286
  @torch.no_grad()
287
- def generate_audio(self, gen_tokens: torch.Tensor, prompt=None, vocal_prompt=None, bgm_prompt=None):
288
  """Generate Audio from tokens"""
289
  assert gen_tokens.dim() == 3
290
  if self.seperate_tokenizer is not None:
@@ -292,7 +279,7 @@ class CodecLM:
292
  gen_tokens_vocal = gen_tokens[:, [1], :]
293
  gen_tokens_bgm = gen_tokens[:, [2], :]
294
  # gen_audio_song = self.audiotokenizer.decode(gen_tokens_song, prompt)
295
- gen_audio_seperate = self.seperate_tokenizer.decode([gen_tokens_vocal, gen_tokens_bgm], vocal_prompt, bgm_prompt)
296
  return gen_audio_seperate
297
  else:
298
  gen_audio = self.audiotokenizer.decode(gen_tokens, prompt)
 
36
  max_duration: tp.Optional[float] = None, seperate_tokenizer: AudioTokenizer = None):
37
  self.name = name
38
  self.audiotokenizer = audiotokenizer
39
+ if self.audiotokenizer:
40
+ self.frame_rate = self.audiotokenizer.frame_rate
41
+ else:
42
+ self.frame_rate = 25
43
  self.lm = lm
44
  self.seperate_tokenizer = seperate_tokenizer
45
  # import pdb; pdb.set_trace()
 
51
  assert max_duration is not None
52
 
53
  self.max_duration: float = max_duration
54
+ self.device = torch.device("cuda")
55
  self.generation_params: dict = {}
56
  # self.set_generation_params(duration=15) # 15 seconds by default
57
  self.set_generation_params(duration=15, extend_stride=self.max_duration // 2)
 
61
  else:
62
  self.autocast = TorchAutocast(enabled=False)
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
65
  top_p: float = 0.0, temperature: float = 1.0,
66
  duration: float = 30.0, cfg_coef: float = 3.0,
 
172
  assert len(lyrics) == 1
173
  texts = [lyric for lyric in lyrics]
174
  audio_qt_embs = []
175
+ target_melody_token_len = self.lm.cfg.prompt_len * self.frame_rate
176
  # import pdb; pdb.set_trace()
177
  if melody_wavs is None:
178
  melody_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
 
194
  melody_tokens = melody_tokens[...,:target_melody_token_len]
195
  elif melody_tokens.shape[-1] < target_melody_token_len:
196
  melody_tokens = torch.cat([melody_tokens, torch.full((1,1,target_melody_token_len - melody_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
197
+
198
+ if bgm_wavs is None:
199
+ assert vocal_wavs is None, "vocal_wavs is not None when bgm_wavs is None"
200
+ bgm_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
201
+ vocal_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
202
+ else:
203
+ assert vocal_wavs is not None, "vocal_wavs is None when bgm_wavs is not None"
204
+ if type(vocal_wavs) == list:
205
+ vocal_wavs = torch.stack(vocal_wavs, dim=0)
206
+ if type(bgm_wavs) == list:
207
+ bgm_wavs = torch.stack(bgm_wavs, dim=0)
208
+ vocal_wavs = vocal_wavs.to(self.device)
209
+ bgm_wavs = bgm_wavs.to(self.device)
210
+ if melody_is_wav:
211
+ vocal_tokens, bgm_tokens = self.seperate_tokenizer.encode(vocal_wavs, bgm_wavs)
212
  else:
213
+ vocal_tokens = vocal_wavs
214
+ bgm_tokens = bgm_wavs
215
+ assert len(vocal_tokens.shape) == len(bgm_tokens.shape) == 3, \
216
+ f"vocal and bgm tokens should have a shape [B, C, T]! " \
217
+ f"got vocal len={vocal_tokens.shape}, and bgm len={bgm_tokens.shape}"
218
+ assert vocal_tokens.shape[-1] == bgm_tokens.shape[-1], \
219
+ f"vocal and bgm tokens should have the same length! " \
220
+ f"got vocal len={vocal_tokens.shape[-1]}, and bgm len={bgm_tokens.shape[-1]}"
221
+ if bgm_tokens.shape[-1] > target_melody_token_len:
222
+ bgm_tokens = bgm_tokens[...,:target_melody_token_len]
223
+ elif bgm_tokens.shape[-1] < target_melody_token_len:
224
+ bgm_tokens = torch.cat([bgm_tokens, torch.full((1,1,target_melody_token_len - bgm_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
225
+ if vocal_tokens.shape[-1] > target_melody_token_len:
226
+ vocal_tokens = vocal_tokens[...,:target_melody_token_len]
227
+ elif vocal_tokens.shape[-1] < target_melody_token_len:
228
+ vocal_tokens = torch.cat([vocal_tokens, torch.full((1,1,target_melody_token_len - vocal_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
229
+ melody_tokens = torch.cat([melody_tokens, vocal_tokens, bgm_tokens], dim=1)
 
 
 
 
 
 
 
 
 
 
230
  assert melody_tokens.shape[-1] == target_melody_token_len
231
  audio_qt_embs = melody_tokens.long()
232
  return texts, audio_qt_embs
 
271
  return gen_tokens
272
 
273
  @torch.no_grad()
274
+ def generate_audio(self, gen_tokens: torch.Tensor, prompt=None, vocal_prompt=None, bgm_prompt=None, chunked=False):
275
  """Generate Audio from tokens"""
276
  assert gen_tokens.dim() == 3
277
  if self.seperate_tokenizer is not None:
 
279
  gen_tokens_vocal = gen_tokens[:, [1], :]
280
  gen_tokens_bgm = gen_tokens[:, [2], :]
281
  # gen_audio_song = self.audiotokenizer.decode(gen_tokens_song, prompt)
282
+ gen_audio_seperate = self.seperate_tokenizer.decode([gen_tokens_vocal, gen_tokens_bgm], vocal_prompt, bgm_prompt, chunked=chunked)
283
  return gen_audio_seperate
284
  else:
285
  gen_audio = self.audiotokenizer.decode(gen_tokens, prompt)
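The hunk above repeats one pattern three times: trim or right-pad the melody, vocal, and bgm token streams to target_melody_token_len = prompt_len × frame_rate, using 16385 as the empty-token id (with frame_rate falling back to 25 when no audiotokenizer is set). A small self-contained sketch of that pattern follows; the helper name and the 10-second prompt length are illustrative, not from the repo.

```python
import torch

PAD_ID = 16385  # empty/placeholder token id used throughout the hunk above

def fit_to_length(tokens: torch.Tensor, target_len: int) -> torch.Tensor:
    """Trim or right-pad a [B, C, T] token tensor to exactly target_len frames."""
    if tokens.shape[-1] > target_len:
        return tokens[..., :target_len]
    if tokens.shape[-1] < target_len:
        pad_shape = (tokens.shape[0], tokens.shape[1], target_len - tokens.shape[-1])
        pad = torch.full(pad_shape, PAD_ID, dtype=torch.long, device=tokens.device)
        return torch.cat([tokens, pad], dim=-1)
    return tokens

# e.g. a 10 s prompt at the 25 frames/s fallback -> 250 frames
melody_tokens = torch.randint(0, PAD_ID, (1, 1, 180))
print(fit_to_length(melody_tokens, 10 * 25).shape)  # torch.Size([1, 1, 250])
```

The other change in this file simply threads a new chunked flag from generate_audio through to seperate_tokenizer.decode, which is what the next file's changes consume.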
codeclm/tokenizer/Flow1dVAE/generate_septoken.py CHANGED
@@ -173,7 +173,7 @@ class Tango:
173
  return codes_vocal, codes_bgm
174
 
175
  @torch.no_grad()
176
- def code2sound(self, codes, prompt_vocal=None, prompt_bgm=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False):
177
  codes_vocal,codes_bgm = codes
178
  codes_vocal = codes_vocal.to(self.device)
179
  codes_bgm = codes_bgm.to(self.device)
@@ -268,11 +268,12 @@ class Tango:
268
  min_samples = int(min_samples * self.sample_rate // 1000 * 40)
269
  hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
270
  ovlp_samples = min_samples - hop_samples
 
271
  with torch.no_grad():
272
  output = None
273
  for i in range(len(latent_list)):
274
  latent = latent_list[i]
275
- cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
276
 
277
  if output is None:
278
  output = cur_output
 
173
  return codes_vocal, codes_bgm
174
 
175
  @torch.no_grad()
176
+ def code2sound(self, codes, prompt_vocal=None, prompt_bgm=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False, chunked=False):
177
  codes_vocal,codes_bgm = codes
178
  codes_vocal = codes_vocal.to(self.device)
179
  codes_bgm = codes_bgm.to(self.device)
 
268
  min_samples = int(min_samples * self.sample_rate // 1000 * 40)
269
  hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
270
  ovlp_samples = min_samples - hop_samples
271
+ torch.cuda.empty_cache()
272
  with torch.no_grad():
273
  output = None
274
  for i in range(len(latent_list)):
275
  latent = latent_list[i]
276
+ cur_output = self.vae.decode_audio(latent, chunked=chunked)[0].detach().cpu()
277
 
278
  if output is None:
279
  output = cur_output
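These changes belong to the same low-memory path: code2sound now accepts a chunked flag, clears the CUDA cache before the decode loop, and forwards chunked to vae.decode_audio. A hedged sketch of that loop in isolation is below; the function name is made up for illustration, and the overlap-stitching of consecutive windows done by the real code is omitted.

```python
import torch

@torch.no_grad()
def decode_latents_lowmem(vae, latent_list, chunked=True):
    """Decode latent windows one at a time, keeping peak GPU memory low.

    Mirrors the pattern added in the hunk above: empty the CUDA cache before
    the loop and move each decoded chunk to CPU immediately. Assumes `vae`
    exposes decode_audio(latent, chunked=...) as shown in the diff.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    chunks = []
    for latent in latent_list:
        cur = vae.decode_audio(latent, chunked=chunked)[0].detach().cpu()
        chunks.append(cur)
    # The real code cross-fades overlapping windows; here we just concatenate.
    return torch.cat(chunks, dim=-1)
```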
codeclm/tokenizer/Flow1dVAE/libs/rvq/core_vq.py CHANGED
@@ -1,366 +1,366 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- #
7
- # This implementation is inspired from
8
- # https://github.com/lucidrains/vector-quantize-pytorch
9
- # which is released under MIT License. Hereafter, the original license:
10
- # MIT License
11
- #
12
- # Copyright (c) 2020 Phil Wang
13
- #
14
- # Permission is hereby granted, free of charge, to any person obtaining a copy
15
- # of this software and associated documentation files (the "Software"), to deal
16
- # in the Software without restriction, including without limitation the rights
17
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
- # copies of the Software, and to permit persons to whom the Software is
19
- # furnished to do so, subject to the following conditions:
20
- #
21
- # The above copyright notice and this permission notice shall be included in all
22
- # copies or substantial portions of the Software.
23
- #
24
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
- # SOFTWARE.
31
-
32
- """Core vector quantization implementation."""
33
-
34
- import typing as tp
35
-
36
- from einops import rearrange, repeat
37
- import torch
38
- from torch import nn
39
- import torch.nn.functional as F
40
-
41
- # from .. import distrib
42
-
43
-
44
- def default(val: tp.Any, d: tp.Any) -> tp.Any:
45
- return val if val is not None else d
46
-
47
-
48
- def ema_inplace(moving_avg, new, decay: float):
49
- moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
50
-
51
-
52
- def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
53
- return (x + epsilon) / (x.sum() + n_categories * epsilon)
54
-
55
-
56
- def uniform_init(*shape: int):
57
- t = torch.empty(shape)
58
- nn.init.kaiming_uniform_(t)
59
- return t
60
-
61
-
62
- def sample_vectors(samples, num: int):
63
- num_samples, device = samples.shape[0], samples.device
64
-
65
- if num_samples >= num:
66
- indices = torch.randperm(num_samples, device=device)[:num]
67
- else:
68
- indices = torch.randint(0, num_samples, (num,), device=device)
69
-
70
- return samples[indices]
71
-
72
-
73
- def kmeans(samples, num_clusters: int, num_iters: int = 10):
74
- dim, dtype = samples.shape[-1], samples.dtype
75
-
76
- means = sample_vectors(samples, num_clusters)
77
-
78
- for _ in range(num_iters):
79
- diffs = rearrange(samples, "n d -> n () d") - rearrange(
80
- means, "c d -> () c d"
81
- )
82
- dists = -(diffs ** 2).sum(dim=-1)
83
-
84
- buckets = dists.max(dim=-1).indices
85
- bins = torch.bincount(buckets, minlength=num_clusters)
86
- zero_mask = bins == 0
87
- bins_min_clamped = bins.masked_fill(zero_mask, 1)
88
-
89
- new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
90
- new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
91
- new_means = new_means / bins_min_clamped[..., None]
92
-
93
- means = torch.where(zero_mask[..., None], means, new_means)
94
-
95
- return means, bins
96
-
97
-
98
- class EuclideanCodebook(nn.Module):
99
- """Codebook with Euclidean distance.
100
- Args:
101
- dim (int): Dimension.
102
- codebook_size (int): Codebook size.
103
- kmeans_init (bool): Whether to use k-means to initialize the codebooks.
104
- If set to true, run the k-means algorithm on the first training batch and use
105
- the learned centroids as initialization.
106
- kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
107
- decay (float): Decay for exponential moving average over the codebooks.
108
- epsilon (float): Epsilon value for numerical stability.
109
- threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
110
- that have an exponential moving average cluster size less than the specified threshold with
111
- randomly selected vector from the current batch.
112
- """
113
- def __init__(
114
- self,
115
- dim: int,
116
- codebook_size: int,
117
- kmeans_init: int = False,
118
- kmeans_iters: int = 10,
119
- decay: float = 0.99,
120
- epsilon: float = 1e-5,
121
- threshold_ema_dead_code: int = 2,
122
- ):
123
- super().__init__()
124
- self.decay = decay
125
- init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
126
- embed = init_fn(codebook_size, dim)
127
-
128
- self.codebook_size = codebook_size
129
-
130
- self.kmeans_iters = kmeans_iters
131
- self.epsilon = epsilon
132
- self.threshold_ema_dead_code = threshold_ema_dead_code
133
-
134
- self.register_buffer("inited", torch.Tensor([not kmeans_init]))
135
- self.register_buffer("cluster_size", torch.zeros(codebook_size))
136
- self.register_buffer("embed", embed)
137
- self.register_buffer("embed_avg", embed.clone())
138
-
139
- self.runed_steps = 0
140
- self.stop_steps = 50_000
141
-
142
- @torch.jit.ignore
143
- def init_embed_(self, data):
144
- if self.inited:
145
- return
146
-
147
- embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
148
- self.embed.data.copy_(embed)
149
- self.embed_avg.data.copy_(embed.clone())
150
- self.cluster_size.data.copy_(cluster_size)
151
- self.inited.data.copy_(torch.Tensor([True]))
152
- # Make sure all buffers across workers are in sync after initialization
153
- distrib.broadcast_tensors(self.buffers())
154
-
155
- def replace_(self, samples, mask):
156
- modified_codebook = torch.where(
157
- mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
158
- )
159
- self.embed.data.copy_(modified_codebook)
160
-
161
- def expire_codes_(self, batch_samples):
162
- if self.threshold_ema_dead_code == 0:
163
- return
164
-
165
- expired_codes = self.cluster_size < self.threshold_ema_dead_code
166
- if not torch.any(expired_codes):
167
- return
168
-
169
- batch_samples = rearrange(batch_samples, "... d -> (...) d")
170
- self.replace_(batch_samples, mask=expired_codes)
171
- # distrib.broadcast_tensors(self.buffers())
172
-
173
- def preprocess(self, x):
174
- x = rearrange(x, "... d -> (...) d")
175
- return x
176
-
177
- def quantize(self, x):
178
- embed = self.embed.t()
179
- dist = -(
180
- x.pow(2).sum(1, keepdim=True)
181
- - 2 * x @ embed
182
- + embed.pow(2).sum(0, keepdim=True)
183
- )
184
- embed_ind = dist.max(dim=-1).indices
185
- return embed_ind
186
-
187
- def postprocess_emb(self, embed_ind, shape):
188
- return embed_ind.view(*shape[:-1])
189
-
190
- def dequantize(self, embed_ind):
191
- quantize = F.embedding(embed_ind, self.embed)
192
- return quantize
193
-
194
- def encode(self, x):
195
- shape = x.shape
196
- # pre-process
197
- x = self.preprocess(x)
198
- # quantize
199
- embed_ind = self.quantize(x)
200
- # post-process
201
- embed_ind = self.postprocess_emb(embed_ind, shape)
202
- return embed_ind
203
-
204
- def decode(self, embed_ind):
205
- quantize = self.dequantize(embed_ind)
206
- return quantize
207
-
208
- def forward(self, x):
209
- shape, dtype = x.shape, x.dtype
210
- x = self.preprocess(x)
211
- # self.init_embed_(x)
212
-
213
- embed_ind = self.quantize(x)
214
- embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
215
- embed_ind = self.postprocess_emb(embed_ind, shape)
216
- quantize = self.dequantize(embed_ind)
217
- self.runed_steps += 1
218
-
219
- if self.training and self.runed_steps < self.stop_steps:
220
- # We do the expiry of code at that point as buffers are in sync
221
- # and all the workers will take the same decision.
222
- self.expire_codes_(x)
223
- ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
224
- embed_sum = x.t() @ embed_onehot
225
- ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
226
- cluster_size = (
227
- laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
228
- * self.cluster_size.sum()
229
- )
230
- embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
231
- self.embed.data.copy_(embed_normalized)
232
-
233
- return quantize, embed_ind
234
-
235
-
236
- class VectorQuantization(nn.Module):
237
- """Vector quantization implementation.
238
- Currently supports only euclidean distance.
239
- Args:
240
- dim (int): Dimension
241
- codebook_size (int): Codebook size
242
- codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
243
- decay (float): Decay for exponential moving average over the codebooks.
244
- epsilon (float): Epsilon value for numerical stability.
245
- kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
246
- kmeans_iters (int): Number of iterations used for kmeans initialization.
247
- threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
248
- that have an exponential moving average cluster size less than the specified threshold with
249
- randomly selected vector from the current batch.
250
- commitment_weight (float): Weight for commitment loss.
251
- """
252
- def __init__(
253
- self,
254
- dim: int,
255
- codebook_size: int,
256
- codebook_dim: tp.Optional[int] = None,
257
- decay: float = 0.99,
258
- epsilon: float = 1e-5,
259
- kmeans_init: bool = True,
260
- kmeans_iters: int = 50,
261
- threshold_ema_dead_code: int = 2,
262
- commitment_weight: float = 1.,
263
- ):
264
- super().__init__()
265
- _codebook_dim: int = default(codebook_dim, dim)
266
-
267
- requires_projection = _codebook_dim != dim
268
- self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
269
- self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
270
-
271
- self.epsilon = epsilon
272
- self.commitment_weight = commitment_weight
273
-
274
- self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
275
- kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
276
- decay=decay, epsilon=epsilon,
277
- threshold_ema_dead_code=threshold_ema_dead_code)
278
- self.codebook_size = codebook_size
279
-
280
- @property
281
- def codebook(self):
282
- return self._codebook.embed
283
-
284
- def encode(self, x):
285
- x = rearrange(x, "b d n -> b n d")
286
- x = self.project_in(x)
287
- embed_in = self._codebook.encode(x)
288
- return embed_in
289
-
290
- def decode(self, embed_ind):
291
- quantize = self._codebook.decode(embed_ind)
292
- quantize = self.project_out(quantize)
293
- quantize = rearrange(quantize, "b n d -> b d n")
294
- return quantize
295
-
296
- def forward(self, x, do_debug=False):
297
- device = x.device
298
- x = rearrange(x, "b d n -> b n d")
299
- x = self.project_in(x)
300
-
301
- quantize, embed_ind = self._codebook(x)
302
-
303
- if self.training:
304
- quantize = x + (quantize - x).detach()
305
-
306
- loss = torch.tensor([0.0], device=device, requires_grad=self.training)
307
-
308
- if self.training:
309
- if self.commitment_weight > 0:
310
- commit_loss = F.mse_loss(quantize.detach(), x)
311
- loss = loss + commit_loss * self.commitment_weight
312
- quantize = self.project_out(quantize)
313
- quantize = rearrange(quantize, "b n d -> b d n")
314
- return quantize, embed_ind, loss
315
-
316
-
317
- class ResidualVectorQuantization(nn.Module):
318
- """Residual vector quantization implementation.
319
- Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
320
- """
321
- def __init__(self, *, num_quantizers, **kwargs):
322
- super().__init__()
323
- self.layers = nn.ModuleList(
324
- [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
325
- )
326
-
327
- def forward(self, x, n_q: tp.Optional[int] = None):
328
- quantized_out = 0.0
329
- residual = x
330
-
331
- all_losses = []
332
- all_indices = []
333
-
334
- n_q = n_q or len(self.layers)
335
-
336
- for layerinx, layer in enumerate(self.layers[:n_q]):
337
- print("Layer {} Used ratio {:.1f}".format(layerinx, (layer._codebook.cluster_size > 1.0).sum() / layer._codebook.cluster_size.shape[0] * 100.))
338
- quantized, indices, loss = layer(residual)
339
- residual = residual - quantized
340
- quantized_out = quantized_out + quantized
341
-
342
- all_indices.append(indices)
343
- all_losses.append(loss)
344
-
345
- out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
346
- return quantized_out, out_indices, out_losses
347
-
348
- def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
349
- residual = x
350
- all_indices = []
351
- n_q = n_q or len(self.layers)
352
- for layer in self.layers[:n_q]:
353
- indices = layer.encode(residual)
354
- quantized = layer.decode(indices)
355
- residual = residual - quantized
356
- all_indices.append(indices)
357
- out_indices = torch.stack(all_indices)
358
- return out_indices
359
-
360
- def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
361
- quantized_out = torch.tensor(0.0, device=q_indices.device)
362
- for i, indices in enumerate(q_indices):
363
- layer = self.layers[i]
364
- quantized = layer.decode(indices)
365
- quantized_out = quantized_out + quantized
366
- return quantized_out
 
(The 366 lines added back for core_vq.py are identical, as rendered, to the 366 lines removed above; no content-level change to the file is visible in this diff.)
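
Taken together, the classes in core_vq.py form a stack: EuclideanCodebook performs nearest-neighbour lookup with EMA codebook updates, VectorQuantization wraps it with optional input/output projections and a commitment loss, and ResidualVectorQuantization chains several such layers so that each one quantizes the residual left by the previous layer (Algorithm 1 of the SoundStream paper cited in its docstring). A minimal shape-level sketch of the API; the import path is assumed, and trained weights would normally be loaded before the codes are meaningful.

    import torch
    # Assumed import path; in this repo the file lives at
    # codeclm/tokenizer/Flow1dVAE/libs/rvq/core_vq.py.
    from libs.rvq.core_vq import ResidualVectorQuantization

    rvq = ResidualVectorQuantization(num_quantizers=4, dim=1024, codebook_size=1024)

    x = torch.randn(2, 1024, 150)   # (batch, dim, time) latent features
    codes = rvq.encode(x, n_q=4)    # -> (n_q, batch, time) integer indices
    x_hat = rvq.decode(codes)       # -> (batch, dim, time) reconstruction
    print(codes.shape, x_hat.shape)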
codeclm/tokenizer/Flow1dVAE/model_1rvq.py CHANGED
@@ -1,710 +1,710 @@
1
- import yaml
2
- import random
3
- import inspect
4
- import numpy as np
5
- from tqdm import tqdm
6
- import typing as tp
7
- from abc import ABC
8
-
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
- import torchaudio
13
-
14
- from tools.torch_tools import wav_to_fbank
15
-
16
- from diffusers.utils.torch_utils import randn_tensor
17
- from transformers import HubertModel
18
- from libs.rvq.descript_quantize3 import ResidualVectorQuantize
19
-
20
- from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
21
- from models_gpt.models.gpt2_config import GPT2Config
22
-
23
- from torch.cuda.amp import autocast
24
-
25
-
26
- from our_MERT_BESTRQ.test import load_model
27
-
28
- class HubertModelWithFinalProj(HubertModel):
29
- def __init__(self, config):
30
- super().__init__(config)
31
-
32
- # The final projection layer is only used for backward compatibility.
33
- # Following https://github.com/auspicious3000/contentvec/issues/6
34
- # Remove this layer is necessary to achieve the desired outcome.
35
- print("hidden_size:",config.hidden_size)
36
- print("classifier_proj_size:",config.classifier_proj_size)
37
- self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
38
-
39
-
40
- class SampleProcessor(torch.nn.Module):
41
- def project_sample(self, x: torch.Tensor):
42
- """Project the original sample to the 'space' where the diffusion will happen."""
43
- """Project back from diffusion space to the actual sample space."""
44
- return z
45
-
46
- class Feature1DProcessor(SampleProcessor):
47
- def __init__(self, dim: int = 100, power_std = 1., \
48
- num_samples: int = 100_000, cal_num_frames: int = 600):
49
- super().__init__()
50
-
51
- self.num_samples = num_samples
52
- self.dim = dim
53
- self.power_std = power_std
54
- self.cal_num_frames = cal_num_frames
55
- self.register_buffer('counts', torch.zeros(1))
56
- self.register_buffer('sum_x', torch.zeros(dim))
57
- self.register_buffer('sum_x2', torch.zeros(dim))
58
- self.register_buffer('sum_target_x2', torch.zeros(dim))
59
- self.counts: torch.Tensor
60
- self.sum_x: torch.Tensor
61
- self.sum_x2: torch.Tensor
62
-
63
- @property
64
- def mean(self):
65
- mean = self.sum_x / self.counts
66
- if(self.counts < 10):
67
- mean = torch.zeros_like(mean)
68
- return mean
69
-
70
- @property
71
- def std(self):
72
- std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
73
- if(self.counts < 10):
74
- std = torch.ones_like(std)
75
- return std
76
-
77
- @property
78
- def target_std(self):
79
- return 1
80
-
81
- def project_sample(self, x: torch.Tensor):
82
- assert x.dim() == 3
83
- if self.counts.item() < self.num_samples:
84
- self.counts += len(x)
85
- self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
86
- self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
87
- rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
88
- x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
89
- return x
90
-
91
- def return_sample(self, x: torch.Tensor):
92
- assert x.dim() == 3
93
- rescale = (self.std / self.target_std) ** self.power_std
94
- # print(rescale, self.mean)
95
- x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
96
- return x
97
-
98
- def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
99
- if(prior_text_encoder_hidden_states.shape[1]<len_size):
100
- prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
101
- torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
102
- prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
103
- dtype=prior_text_encoder_hidden_states.dtype)],1)
104
- prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
105
- else:
106
- prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
107
- prior_text_mask = prior_text_mask[:,0:len_size]
108
- prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
109
- return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
110
-
111
- class BASECFM(torch.nn.Module, ABC):
112
- def __init__(
113
- self,
114
- estimator,
115
- mlp,
116
- ssl_layer
117
- ):
118
- super().__init__()
119
- self.sigma_min = 1e-4
120
-
121
- self.estimator = estimator
122
- self.mlp = mlp
123
- self.ssl_layer = ssl_layer
124
-
125
- @torch.inference_mode()
126
- def forward(self, mu, n_timesteps, temperature=1.0):
127
- """Forward diffusion
128
-
129
- Args:
130
- mu (torch.Tensor): output of encoder
131
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
132
- n_timesteps (int): number of diffusion steps
133
- temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
134
-
135
- Returns:
136
- sample: generated mel-spectrogram
137
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
138
- """
139
- z = torch.randn_like(mu) * temperature
140
- t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
141
- return self.solve_euler(z, t_span=t_span)
142
-
143
- def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
144
- """
145
- Fixed euler solver for ODEs.
146
- Args:
147
- x (torch.Tensor): random noise
148
- t_span (torch.Tensor): n_timesteps interpolated
149
- shape: (n_timesteps + 1,)
150
- mu (torch.Tensor): output of encoder
151
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
152
- """
153
- t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
154
- noise = x.clone()
155
-
156
- # I am storing this because I can later plot it by putting a debugger here and saving it to a file
157
- # Or in future might add like a return_all_steps flag
158
- sol = []
159
-
160
- for step in tqdm(range(1, len(t_span))):
161
- # print("incontext_x.shape:",incontext_x.shape)
162
- # print("noise.shape:",noise.shape)
163
- # print("t.shape:",t.shape)
164
- x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
165
- if(guidance_scale > 1.0):
166
-
167
- model_input = torch.cat([ \
168
- torch.cat([latent_mask_input, latent_mask_input], 0), \
169
- torch.cat([incontext_x, incontext_x], 0), \
170
- torch.cat([torch.zeros_like(mu), mu], 0), \
171
- torch.cat([x, x], 0), \
172
- ], 2)
173
- timestep=t.unsqueeze(-1).repeat(2)
174
-
175
- dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
176
- dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0)
177
- dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond)
178
- else:
179
- model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
180
- timestep=t.unsqueeze(-1)
181
- dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
182
-
183
- dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
184
- # print("dphi_dt.shape:",dphi_dt.shape)
185
- # print("x.shape:",x.shape)
186
-
187
- x = x + dt * dphi_dt
188
- t = t + dt
189
- sol.append(x)
190
- if step < len(t_span) - 1:
191
- dt = t_span[step + 1] - t
192
-
193
- return sol[-1]
194
-
195
- def projection_loss(self,hidden_proj, bestrq_emb):
196
- bsz = hidden_proj.shape[0]
197
-
198
- hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
199
- bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
200
-
201
- proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
202
- proj_loss = 1+proj_loss.mean()
203
-
204
- return proj_loss
205
-
206
- def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
207
- """Computes diffusion loss
208
-
209
- Args:
210
- x1 (torch.Tensor): Target
211
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
212
- mu (torch.Tensor): output of encoder
213
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
214
-
215
- Returns:
216
- loss: conditional flow matching loss
217
- y: conditional flow
218
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
219
- """
220
- b = mu[0].shape[0]
221
- len_x = x1.shape[2]
222
- # random timestep
223
- if(validation_mode):
224
- t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
225
- else:
226
- t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
227
- # sample noise p(x_0)
228
- z = torch.randn_like(x1)
229
-
230
- y = (1 - (1 - self.sigma_min) * t) * z + t * x1
231
- u = x1 - (1 - self.sigma_min) * z
232
- # print("y.shape:",y.shape)
233
- #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state
234
- model_input = torch.cat([*mu,y], 2)
235
- t=t.squeeze(-1).squeeze(-1)
236
- # print("model_input.shape:",model_input.shape)
237
- # print("attention_mask.shape:",attention_mask.shape)
238
- out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
239
- hidden_layer = out.hidden_states[self.ssl_layer]
240
- hidden_proj = self.mlp(hidden_layer)
241
- # print("hidden_proj.shape:",hidden_proj.shape)
242
- # print("mert_emb.shape:",mert_emb.shape)
243
- # exit()
244
-
245
-
246
- out = out.last_hidden_state
247
-
248
- out=out[:,:,-len_x:]
249
- # out=self.proj_out(out)
250
-
251
- weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
252
- # print("out.shape",out.shape)
253
- # print("u.shape",u.shape)
254
- loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
255
- # print("hidden_proj.shape:",hidden_proj.shape)
256
- # print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
257
- loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
258
- loss = loss_re + loss_cos * 0.5
259
- # print("loss_cos:",loss_cos,loss_cos.device)
260
- print("loss:",loss,loss.device)
261
- # exit()
262
- return loss, loss_re, loss_cos
263
-
264
- class PromptCondAudioDiffusion(nn.Module):
265
- def __init__(
266
- self,
267
- num_channels,
268
- unet_model_name=None,
269
- unet_model_config_path=None,
270
- snr_gamma=None,
271
- hubert_layer=None,
272
- ssl_layer=None,
273
- uncondition=True,
274
- out_paint=False,
275
- ):
276
- super().__init__()
277
-
278
- assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
279
-
280
- self.unet_model_name = unet_model_name
281
- self.unet_model_config_path = unet_model_config_path
282
- self.snr_gamma = snr_gamma
283
- self.uncondition = uncondition
284
- self.num_channels = num_channels
285
- self.hubert_layer = hubert_layer
286
- self.ssl_layer = ssl_layer
287
-
288
- # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
289
- self.normfeat = Feature1DProcessor(dim=64)
290
-
291
- self.sample_rate = 48000
292
- self.num_samples_perseg = self.sample_rate * 20 // 1000
293
- self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
294
- self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
295
- # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
296
- # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
297
- self.bestrq = load_model(
298
- model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq',
299
- checkpoint_dir='ckpt/encode-s12k.pt',
300
- )
301
- self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
302
- self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
303
- for v in self.bestrq.parameters():v.requires_grad = False
304
- self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
305
- for v in self.rvq_bestrq_emb.parameters():v.requires_grad = False
306
- self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
307
- for v in self.hubert.parameters():v.requires_grad = False
308
- self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
309
- # self.xvecmodel = XVECModel()
310
- config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
311
- unet = GPT2Model(config)
312
- mlp = nn.Sequential(
313
- nn.Linear(1200, 1024),
314
- nn.SiLU(),
315
- nn.Linear(1024, 1024),
316
- nn.SiLU(),
317
- nn.Linear(1024, 768)
318
- )
319
- self.set_from = "random"
320
- self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
321
- self.mask_emb = torch.nn.Embedding(3, 48)
322
- print("Transformer initialized from pretrain.")
323
- torch.cuda.empty_cache()
324
- # self.unet.set_attn_processor(AttnProcessor2_0())
325
- # self.unet.set_use_memory_efficient_attention_xformers(True)
326
-
327
- # self.start_embedding = nn.Parameter(torch.randn(1,1024))
328
- # self.end_embedding = nn.Parameter(torch.randn(1,1024))
329
-
330
- def compute_snr(self, timesteps):
331
- """
332
- Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
333
- """
334
- alphas_cumprod = self.noise_scheduler.alphas_cumprod
335
- sqrt_alphas_cumprod = alphas_cumprod**0.5
336
- sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
337
-
338
- # Expand the tensors.
339
- # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
340
- sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
341
- while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
342
- sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
343
- alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
344
-
345
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
346
- while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
347
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
348
- sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
349
-
350
- # Compute SNR.
351
- snr = (alpha / sigma) ** 2
352
- return snr
353
-
354
- def preprocess_audio(self, input_audios, threshold=0.9):
355
- assert len(input_audios.shape) == 2, input_audios.shape
356
- norm_value = torch.ones_like(input_audios[:,0])
357
- max_volume = input_audios.abs().max(dim=-1)[0]
358
- norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
359
- return input_audios/norm_value.unsqueeze(-1)
360
-
361
- def extract_wav2vec_embeds(self, input_audios,output_len):
362
- wav2vec_stride = 2
363
-
364
- wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
365
- # print(wav2vec_embeds)
366
- # print("audio.shape:",input_audios.shape)
367
- wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer]
368
- # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape)
369
- wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
370
- return wav2vec_embeds_last
371
-
372
- def extract_mert_embeds(self, input_audios):
373
- prompt_stride = 3
374
- inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
375
- input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
376
- prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
377
- mert_emb= prompt_embeds[-1]
378
- mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1)
379
-
380
- return mert_emb
381
-
382
- def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer):
383
- self.bestrq.eval()
384
- # print("audio shape:",input_audio_0.shape)
385
- input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
386
- # print("input_wav_mean.shape:",input_wav_mean.shape)
387
- # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device)
388
- input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
389
- layer_results = input_wav_mean['layer_results']
390
- # print("layer_results.shape:",layer_results[layer].shape)
391
- bestrq_emb = layer_results[layer]
392
- bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
393
- #[b,t,1024] t=t/960
394
- #35.84s->batch,896,1024
395
- return bestrq_emb
396
-
397
-
398
- def extract_spk_embeds(self, input_audios):
399
- spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
400
- spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
401
- return spk_embeds
402
-
403
- def extract_lyric_feats(self, lyric):
404
- with torch.no_grad():
405
- try:
406
- text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
407
- except:
408
- text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
409
- text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
410
- text_mask = text_mask.to(self.device)
411
- text_encoder_hidden_states, text_mask, text_prompt_embeds = \
412
- pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
413
- text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
414
- return text_encoder_hidden_states, text_mask
415
-
416
- def extract_energy_bar(self, input_audios):
417
- if(input_audios.shape[-1] % self.num_samples_perseg > 0):
418
- energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
419
- else:
420
- energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
421
- energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
422
- energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
423
- energy_embedding = self.energy_embedding(energy_bar)
424
- energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
425
- return energy_embedding
426
-
427
- def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \
428
- additional_feats = ['spk', 'lyric'], \
429
- train_rvq=True, train_ssl=False,layer=5):
430
- if not hasattr(self,"device"):
431
- self.device = input_audios.device
432
- if not hasattr(self,"dtype"):
433
- self.dtype = input_audios.dtype
434
- device = self.device
435
- input_audio_0 = input_audios[:,0,:]
436
- input_audio_1 = input_audios[:,1,:]
437
- input_audio_0 = self.preprocess_audio(input_audio_0)
438
- input_audio_1 = self.preprocess_audio(input_audio_1)
439
- input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0
440
- # energy_embedding = self.extract_energy_bar(input_audios)
441
- # print("energy_embedding.shape:",energy_embedding.shape)
442
- # with autocast(enabled=False):
443
- if(train_ssl):
444
- self.wav2vec.train()
445
- wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
446
- self.clap_embd_extractor.train()
447
- prompt_embeds = self.extract_mert_embeds(input_audios)
448
- if('spk' in additional_feats):
449
- self.xvecmodel.train()
450
- spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
451
- else:
452
- with torch.no_grad():
453
- with autocast(enabled=False):
454
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
455
- # mert_emb = self.extract_mert_embeds(input_audios_mert)
456
-
457
- wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2])
458
-
459
- bestrq_emb = bestrq_emb.detach()
460
- if('lyric' in additional_feats):
461
- text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
462
- else:
463
- text_encoder_hidden_states, text_mask = None, None
464
-
465
- # prompt_embeds_13 = torch.cat([mert_emb_13, energy_embedding], 1)
466
- # print("prompt_embes.shape:",prompt_embeds.shape)
467
- #prompt_embes.shape: torch.Size([3, 1088, 896])
468
- # print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
469
- #wav2vec_embeds.shape:torch.Size([3, 1024, 896])
470
- if(train_rvq):
471
- quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
472
- else:
473
- bestrq_emb = bestrq_emb.float()
474
- self.rvq_bestrq_emb.eval()
475
- # with autocast(enabled=False):
476
- quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
477
- commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
478
- codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
479
- quantized_bestrq_emb = quantized_bestrq_emb.detach()
480
-
481
- commitment_loss = commitment_loss_bestrq_emb
482
- codebook_loss = codebook_loss_bestrq_emb
483
-
484
-
485
- alpha=1
486
- quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
487
-
488
- # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
489
- # print("latent_masks.shape:",latent_masks.shape)
490
- # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
491
-
492
-
493
-
494
- scenario = np.random.choice(['start_seg', 'other_seg'])
495
- if(scenario == 'other_seg'):
496
- for binx in range(input_audios.shape[0]):
497
- # latent_masks[binx,0:64] = 1
498
- latent_masks[binx,0:random.randint(64,128)] = 1
499
- quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
500
- # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
501
- # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape)
502
- # print("latent_masks.shape:",latent_masks.shape)
503
- quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
504
- + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
505
-
506
-
507
-
508
-
509
- if self.uncondition:
510
- mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
511
- if len(mask_indices) > 0:
512
- quantized_bestrq_emb[mask_indices] = 0
513
- # print("latents.shape:",latents.shape)
514
- latents = latents.permute(0,2,1).contiguous()
515
- latents = self.normfeat.project_sample(latents)
516
- latents = latents.permute(0,2,1).contiguous()
517
- incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
518
- attention_mask=(latent_masks > 0.5)
519
- B, L = attention_mask.size()
520
- attention_mask = attention_mask.view(B, 1, L)
521
- attention_mask = attention_mask * attention_mask.transpose(-1, -2)
522
- attention_mask = attention_mask.unsqueeze(1)
523
- # print("incontext_latents.shape:",incontext_latents.shape)
524
- # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
525
- latent_mask_input = self.mask_emb(latent_masks)
526
- #64+48+64+1024
527
- loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
528
- return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
529
-
530
- def init_device_dtype(self, device, dtype):
531
- self.device = device
532
- self.dtype = dtype
533
-
534
- @torch.no_grad()
535
- def fetch_codes(self, input_audios, additional_feats,layer):
536
- input_audio_0 = input_audios[[0],:]
537
- input_audio_1 = input_audios[[1],:]
538
- input_audio_0 = self.preprocess_audio(input_audio_0)
539
- input_audio_1 = self.preprocess_audio(input_audio_1)
540
-
541
- self.bestrq.eval()
542
-
543
- # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
544
- # bestrq_middle = bestrq_middle.detach()
545
- # bestrq_last = bestrq_last.detach()
546
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
547
- bestrq_emb = bestrq_emb.detach()
548
-
549
- # self.rvq_bestrq_middle.eval()
550
- # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
551
- # self.rvq_bestrq_last.eval()
552
- # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
553
-
554
- self.rvq_bestrq_emb.eval()
555
- quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
556
-
557
-
558
- if('spk' in additional_feats):
559
- self.xvecmodel.eval()
560
- spk_embeds = self.extract_spk_embeds(input_audios)
561
- else:
562
- spk_embeds = None
563
-
564
- # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
565
- # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
566
- # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
567
- return [codes_bestrq_emb], [bestrq_emb], spk_embeds
568
- # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
569
-
570
-
571
- @torch.no_grad()
572
- def fetch_codes_batch(self, input_audios, additional_feats,layer):
573
- input_audio_0 = input_audios[:,0,:]
574
- input_audio_1 = input_audios[:,1,:]
575
- input_audio_0 = self.preprocess_audio(input_audio_0)
576
- input_audio_1 = self.preprocess_audio(input_audio_1)
577
-
578
- self.bestrq.eval()
579
-
580
- # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
581
- # bestrq_middle = bestrq_middle.detach()
582
- # bestrq_last = bestrq_last.detach()
583
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
584
- bestrq_emb = bestrq_emb.detach()
585
-
586
- # self.rvq_bestrq_middle.eval()
587
- # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
588
- # self.rvq_bestrq_last.eval()
589
- # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
590
-
591
- self.rvq_bestrq_emb.eval()
592
- quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
593
-
594
-
595
- if('spk' in additional_feats):
596
- self.xvecmodel.eval()
597
- spk_embeds = self.extract_spk_embeds(input_audios)
598
- else:
599
- spk_embeds = None
600
-
601
- # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
602
- # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
603
- # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
604
- return [codes_bestrq_emb], [bestrq_emb], spk_embeds
605
-
606
- @torch.no_grad()
607
- def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127,
608
- guidance_scale=2, num_steps=20,
609
- disable_progress=True, scenario='start_seg'):
610
- classifier_free_guidance = guidance_scale > 1.0
611
- device = self.device
612
- dtype = self.dtype
613
- # codes_bestrq_middle, codes_bestrq_last = codes
614
- codes_bestrq_emb = codes[0]
615
-
616
-
617
- batch_size = codes_bestrq_emb.shape[0]
618
-
619
-
620
- quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
621
- # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
622
- quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
623
- print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
624
- # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
625
-
626
-
627
-
628
-
629
- if('spk' in additional_feats):
630
- spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
631
-
632
- num_frames = quantized_bestrq_emb.shape[1]
633
-
634
- num_channels_latents = self.num_channels
635
- shape = (batch_size, num_frames, 64)
636
- latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
637
-
638
-
639
-
640
- latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
641
- latent_masks[:,0:latent_length] = 2
642
- if(scenario=='other_seg'):
643
- latent_masks[:,0:incontext_length] = 1
644
-
645
-
646
-
647
- quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
648
- + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
649
- true_latents = true_latents.permute(0,2,1).contiguous()
650
- true_latents = self.normfeat.project_sample(true_latents)
651
- true_latents = true_latents.permute(0,2,1).contiguous()
652
- incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
653
- incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
654
-
655
-
656
- attention_mask=(latent_masks > 0.5)
657
- B, L = attention_mask.size()
658
- attention_mask = attention_mask.view(B, 1, L)
659
- attention_mask = attention_mask * attention_mask.transpose(-1, -2)
660
- attention_mask = attention_mask.unsqueeze(1)
661
- latent_mask_input = self.mask_emb(latent_masks)
662
-
663
- if('spk' in additional_feats):
664
- # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
665
- additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1)
666
- else:
667
- # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
668
- additional_model_input = torch.cat([quantized_bestrq_emb],1)
669
-
670
- temperature = 1.0
671
- t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
672
- latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
673
-
674
- latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
675
- latents = latents.permute(0,2,1).contiguous()
676
- latents = self.normfeat.return_sample(latents)
677
- # latents = latents.permute(0,2,1).contiguous()
678
- return latents
679
-
680
- @torch.no_grad()
681
- def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
682
- disable_progress=True,layer=5,scenario='start_seg'):
683
- codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
684
-
685
- latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
686
- guidance_scale=guidance_scale, num_steps=num_steps, \
687
- disable_progress=disable_progress,scenario=scenario)
688
- return latents
689
-
690
- @torch.no_grad()
691
- def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
692
- disable_progress=True,layer=5,scenario='start_seg'):
693
- codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
694
- import time
695
- start = time.time()
696
- latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
697
- guidance_scale=guidance_scale, num_steps=num_steps, \
698
- disable_progress=disable_progress,scenario=scenario)
699
- return latents,time.time()-start
700
-
701
- def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
702
- divisor = 4
703
- shape = (batch_size, num_channels_latents, num_frames, 32)
704
- if(num_frames%divisor>0):
705
- num_frames = round(num_frames/float(divisor))*divisor
706
- shape = (batch_size, num_channels_latents, num_frames, 32)
707
- latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
708
- return latents
709
-
710
-
 
1
+ import yaml
2
+ import random
3
+ import inspect
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ import typing as tp
7
+ from abc import ABC
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torchaudio
13
+
14
+ from tools.torch_tools import wav_to_fbank
15
+
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+ from transformers import HubertModel
18
+ from libs.rvq.descript_quantize3 import ResidualVectorQuantize
19
+
20
+ from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
21
+ from models_gpt.models.gpt2_config import GPT2Config
22
+
23
+ from torch.cuda.amp import autocast
24
+
25
+
26
+ from our_MERT_BESTRQ.test import load_model
27
+
28
+ class HubertModelWithFinalProj(HubertModel):
29
+ def __init__(self, config):
30
+ super().__init__(config)
31
+
32
+ # The final projection layer is only used for backward compatibility.
33
+ # Following https://github.com/auspicious3000/contentvec/issues/6
34
+ # Removing this layer is necessary to achieve the desired outcome.
35
+ print("hidden_size:",config.hidden_size)
36
+ print("classifier_proj_size:",config.classifier_proj_size)
37
+ self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
38
+
39
+
40
+ class SampleProcessor(torch.nn.Module):
41
+ def project_sample(self, x: torch.Tensor):
+ """Project the original sample to the 'space' where the diffusion will happen."""
+ return x
+
+ def return_sample(self, z: torch.Tensor):
+ """Project back from diffusion space to the actual sample space."""
+ return z
45
+
46
+ class Feature1DProcessor(SampleProcessor):
47
+ def __init__(self, dim: int = 100, power_std = 1., \
48
+ num_samples: int = 100_000, cal_num_frames: int = 600):
49
+ super().__init__()
50
+
51
+ self.num_samples = num_samples
52
+ self.dim = dim
53
+ self.power_std = power_std
54
+ self.cal_num_frames = cal_num_frames
55
+ self.register_buffer('counts', torch.zeros(1))
56
+ self.register_buffer('sum_x', torch.zeros(dim))
57
+ self.register_buffer('sum_x2', torch.zeros(dim))
58
+ self.register_buffer('sum_target_x2', torch.zeros(dim))
59
+ self.counts: torch.Tensor
60
+ self.sum_x: torch.Tensor
61
+ self.sum_x2: torch.Tensor
62
+
63
+ @property
64
+ def mean(self):
65
+ mean = self.sum_x / self.counts
66
+ if(self.counts < 10):
67
+ mean = torch.zeros_like(mean)
68
+ return mean
69
+
70
+ @property
71
+ def std(self):
72
+ std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
73
+ if(self.counts < 10):
74
+ std = torch.ones_like(std)
75
+ return std
76
+
77
+ @property
78
+ def target_std(self):
79
+ return 1
80
+
81
+ def project_sample(self, x: torch.Tensor):
82
+ assert x.dim() == 3
83
+ if self.counts.item() < self.num_samples:
84
+ self.counts += len(x)
85
+ self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
86
+ self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
87
+ rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
88
+ x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
89
+ return x
90
+
91
+ def return_sample(self, x: torch.Tensor):
92
+ assert x.dim() == 3
93
+ rescale = (self.std / self.target_std) ** self.power_std
94
+ # print(rescale, self.mean)
95
+ x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
96
+ return x
97
+
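# Illustrative sketch (toy values, not part of the committed file): Feature1DProcessor
# accumulates per-dimension mean/std over early batches and then whitens features;
# return_sample inverts project_sample once enough batches (counts >= 10) were seen.
import torch

proc = Feature1DProcessor(dim=64)
x = torch.randn(4, 64, 600) * 3.0 + 1.5        # (batch, dim, frames)
for _ in range(5):                             # counts reaches 20, so the stats become valid
    proc.project_sample(x)
y = proc.project_sample(x)                     # normalized latents fed to the flow model
x_rec = proc.return_sample(y)                  # mapped back to the original latent scale
print(torch.allclose(x, x_rec, atol=1e-4))     # expected: True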
98
+ def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
99
+ if(prior_text_encoder_hidden_states.shape[1]<len_size):
100
+ prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
101
+ torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
102
+ prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
103
+ dtype=prior_text_encoder_hidden_states.dtype)],1)
104
+ prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
105
+ else:
106
+ prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
107
+ prior_text_mask = prior_text_mask[:,0:len_size]
108
+ prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
109
+ return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
110
+
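# Illustrative sketch (toy tensors, not part of the committed file): pad_or_tunc_tolen
# zero-pads hidden states and their mask up to len_size (or truncates longer ones) and
# returns the hidden states channel-first.
import torch

hidden = torch.randn(2, 50, 512)               # (batch, seq, dim), shorter than 77
mask = torch.ones(2, 50)
out_hidden, out_mask, _ = pad_or_tunc_tolen(hidden, mask, None, len_size=77)
print(out_hidden.shape, out_mask.shape)        # (2, 512, 77) and (2, 77)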
111
+ class BASECFM(torch.nn.Module, ABC):
112
+ def __init__(
113
+ self,
114
+ estimator,
115
+ mlp,
116
+ ssl_layer
117
+ ):
118
+ super().__init__()
119
+ self.sigma_min = 1e-4
120
+
121
+ self.estimator = estimator
122
+ self.mlp = mlp
123
+ self.ssl_layer = ssl_layer
124
+
125
+ @torch.inference_mode()
126
+ def forward(self, mu, n_timesteps, temperature=1.0):
127
+ """Forward diffusion
128
+
129
+ Args:
130
+ mu (torch.Tensor): output of encoder
131
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
132
+ n_timesteps (int): number of diffusion steps
133
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
134
+
135
+ Returns:
136
+ sample: generated latent feature sequence
137
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
138
+ """
139
+ z = torch.randn_like(mu) * temperature
140
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
141
+ return self.solve_euler(z, t_span=t_span)
142
+
143
+ def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
144
+ """
145
+ Fixed-step Euler solver for the flow-matching ODE.
146
+ Args:
147
+ x (torch.Tensor): random noise
148
+ t_span (torch.Tensor): n_timesteps interpolated
149
+ shape: (n_timesteps + 1,)
150
+ mu (torch.Tensor): output of encoder
151
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
152
+ """
153
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
154
+ noise = x.clone()
155
+
156
+ # Intermediate states are collected so they can be inspected or plotted while debugging.
157
+ # A return_all_steps flag could be added in the future.
158
+ sol = []
159
+
160
+ for step in tqdm(range(1, len(t_span))):
161
+ # print("incontext_x.shape:",incontext_x.shape)
162
+ # print("noise.shape:",noise.shape)
163
+ # print("t.shape:",t.shape)
164
+ x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
165
+ if(guidance_scale > 1.0):
166
+
167
+ model_input = torch.cat([ \
168
+ torch.cat([latent_mask_input, latent_mask_input], 0), \
169
+ torch.cat([incontext_x, incontext_x], 0), \
170
+ torch.cat([torch.zeros_like(mu), mu], 0), \
171
+ torch.cat([x, x], 0), \
172
+ ], 2)
173
+ timestep=t.unsqueeze(-1).repeat(2)
174
+
175
+ dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
176
+ dphi_dt_uncond, dphi_dt_cond = dphi_dt.chunk(2, 0)
177
+ dphi_dt = dphi_dt_uncond + guidance_scale * (dphi_dt_cond - dphi_dt_uncond)
178
+ else:
179
+ model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
180
+ timestep=t.unsqueeze(-1)
181
+ dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
182
+
183
+ dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
184
+ # print("dphi_dt.shape:",dphi_dt.shape)
185
+ # print("x.shape:",x.shape)
186
+
187
+ x = x + dt * dphi_dt
188
+ t = t + dt
189
+ sol.append(x)
190
+ if step < len(t_span) - 1:
191
+ dt = t_span[step + 1] - t
192
+
193
+ return sol[-1]
194
+
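# Illustrative sketch (stand-in tensors, not part of the committed file) of the
# classifier-free-guidance Euler update inside solve_euler: the unconditional and
# conditional velocity predictions are blended with guidance_scale before each step.
import torch

guidance_scale = 2.0
dphi_dt_uncond = torch.zeros(1, 250, 64)       # stand-in for the unconditional velocity
dphi_dt_cond = torch.ones(1, 250, 64)          # stand-in for the conditional velocity
dphi_dt = dphi_dt_uncond + guidance_scale * (dphi_dt_cond - dphi_dt_uncond)

x = torch.randn(1, 250, 64)                    # current latent state
num_steps = 20
dt = 1.0 / num_steps                           # uniform spacing of torch.linspace(0, 1, num_steps + 1)
x = x + dt * dphi_dt                           # one fixed-step Euler update of the flow ODE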
195
+ def projection_loss(self,hidden_proj, bestrq_emb):
196
+ bsz = hidden_proj.shape[0]
197
+
198
+ hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
199
+ bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
200
+
201
+ proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
202
+ proj_loss = 1+proj_loss.mean()
203
+
204
+ return proj_loss
205
+
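# Illustrative sketch (toy tensors, not part of the committed file): projection_loss is
# 1 minus the mean cosine similarity, so perfectly aligned embeddings give ~0 and
# opposite embeddings give ~2.
import torch
import torch.nn.functional as F

a = F.normalize(torch.randn(2, 50, 768), dim=-1)
loss_aligned = 1 + (-(a * a).sum(dim=-1)).mean()
loss_opposite = 1 + (-(a * (-a)).sum(dim=-1)).mean()
print(loss_aligned.item(), loss_opposite.item())   # ~0.0 and ~2.0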
206
+ def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
207
+ """Computes diffusion loss
208
+
209
+ Args:
210
+ x1 (torch.Tensor): Target
211
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
212
+ mu (torch.Tensor): output of encoder
213
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
214
+
215
+ Returns:
216
+ loss: total loss (flow-matching term plus 0.5 * projection term)
217
+ loss_re: conditional flow-matching (velocity regression) loss
218
+ loss_cos: cosine projection loss against the SSL embeddings
219
+ """
220
+ b = mu[0].shape[0]
221
+ len_x = x1.shape[2]
222
+ # random timestep
223
+ if(validation_mode):
224
+ t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
225
+ else:
226
+ t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
227
+ # sample noise p(x_0)
228
+ z = torch.randn_like(x1)
229
+
230
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
231
+ u = x1 - (1 - self.sigma_min) * z
232
+ # print("y.shape:",y.shape)
233
+ #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state
234
+ model_input = torch.cat([*mu,y], 2)
235
+ t=t.squeeze(-1).squeeze(-1)
236
+ # print("model_input.shape:",model_input.shape)
237
+ # print("attention_mask.shape:",attention_mask.shape)
238
+ out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
239
+ hidden_layer = out.hidden_states[self.ssl_layer]
240
+ hidden_proj = self.mlp(hidden_layer)
241
+ # print("hidden_proj.shape:",hidden_proj.shape)
242
+ # print("mert_emb.shape:",mert_emb.shape)
243
+ # exit()
244
+
245
+
246
+ out = out.last_hidden_state
247
+
248
+ out=out[:,:,-len_x:]
249
+ # out=self.proj_out(out)
250
+
251
+ weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
252
+ # print("out.shape",out.shape)
253
+ # print("u.shape",u.shape)
254
+ loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
255
+ # print("hidden_proj.shape:",hidden_proj.shape)
256
+ # print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
257
+ loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
258
+ loss = loss_re + loss_cos * 0.5
259
+ # print("loss_cos:",loss_cos,loss_cos.device)
260
+ print("loss:",loss,loss.device)
261
+ # exit()
262
+ return loss, loss_re, loss_cos
263
+
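# Illustrative sketch (toy tensors, not part of the committed file) of the conditional
# flow-matching construction in compute_loss: y_t interpolates noise z and target x1,
# and the regression target u is the time-derivative of that path.
import torch

sigma_min = 1e-4
x1 = torch.randn(2, 64, 100)                   # target latents
z = torch.randn_like(x1)                       # sampled noise
t = torch.rand(2, 1, 1)
y = (1 - (1 - sigma_min) * t) * z + t * x1
u = x1 - (1 - sigma_min) * z                   # d y / d t, independent of t
eps = 1e-3                                     # finite-difference check of the derivative
y_eps = (1 - (1 - sigma_min) * (t + eps)) * z + (t + eps) * x1
print(torch.allclose((y_eps - y) / eps, u, atol=1e-2))   # expected: True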
264
+ class PromptCondAudioDiffusion(nn.Module):
265
+ def __init__(
266
+ self,
267
+ num_channels,
268
+ unet_model_name=None,
269
+ unet_model_config_path=None,
270
+ snr_gamma=None,
271
+ hubert_layer=None,
272
+ ssl_layer=None,
273
+ uncondition=True,
274
+ out_paint=False,
275
+ ):
276
+ super().__init__()
277
+
278
+ assert unet_model_name is not None or unet_model_config_path is not None, "Either a pretrained UNet model name or a config file path is required"
279
+
280
+ self.unet_model_name = unet_model_name
281
+ self.unet_model_config_path = unet_model_config_path
282
+ self.snr_gamma = snr_gamma
283
+ self.uncondition = uncondition
284
+ self.num_channels = num_channels
285
+ self.hubert_layer = hubert_layer
286
+ self.ssl_layer = ssl_layer
287
+
288
+ # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
289
+ self.normfeat = Feature1DProcessor(dim=64)
290
+
291
+ self.sample_rate = 48000
292
+ self.num_samples_perseg = self.sample_rate * 20 // 1000
293
+ self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
294
+ self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
295
+ # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
296
+ # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
297
+ self.bestrq = load_model(
298
+ model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq',
299
+ checkpoint_dir='ckpt/encode-s12k.pt',
300
+ )
301
+ self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
302
+ self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
303
+ for v in self.bestrq.parameters():v.requires_grad = False
304
+ self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
305
+ for v in self.rvq_bestrq_emb.parameters():v.requires_grad = False
306
+ self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
307
+ for v in self.hubert.parameters():v.requires_grad = False
308
+ self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
309
+ # self.xvecmodel = XVECModel()
310
+ config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
311
+ unet = GPT2Model(config)
312
+ mlp = nn.Sequential(
313
+ nn.Linear(1200, 1024),
314
+ nn.SiLU(),
315
+ nn.Linear(1024, 1024),
316
+ nn.SiLU(),
317
+ nn.Linear(1024, 768)
318
+ )
319
+ self.set_from = "random"
320
+ self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
321
+ self.mask_emb = torch.nn.Embedding(3, 48)
322
+ print("Transformer initialized from pretrain.")
323
+ torch.cuda.empty_cache()
324
+ # self.unet.set_attn_processor(AttnProcessor2_0())
325
+ # self.unet.set_use_memory_efficient_attention_xformers(True)
326
+
327
+ # self.start_embedding = nn.Parameter(torch.randn(1,1024))
328
+ # self.end_embedding = nn.Parameter(torch.randn(1,1024))
329
+
330
+ def compute_snr(self, timesteps):
331
+ """
332
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
333
+ """
334
+ alphas_cumprod = self.noise_scheduler.alphas_cumprod
335
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
336
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
337
+
338
+ # Expand the tensors.
339
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
340
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
341
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
342
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
343
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
344
+
345
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
346
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
347
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
348
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
349
+
350
+ # Compute SNR.
351
+ snr = (alpha / sigma) ** 2
352
+ return snr
353
+
354
+ def preprocess_audio(self, input_audios, threshold=0.9):
355
+ assert len(input_audios.shape) == 2, input_audios.shape
356
+ norm_value = torch.ones_like(input_audios[:,0])
357
+ max_volume = input_audios.abs().max(dim=-1)[0]
358
+ norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
359
+ return input_audios/norm_value.unsqueeze(-1)
360
+
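# Illustrative sketch (toy values, not part of the committed file) of the peak limiting
# in preprocess_audio: rows whose absolute peak exceeds `threshold` are rescaled so the
# peak lands exactly at the threshold; quieter rows pass through unchanged.
import torch

threshold = 0.9
audio = torch.tensor([[0.0, 1.8, -0.3],
                      [0.0, 0.5, -0.3]])
norm_value = torch.ones_like(audio[:, 0])
max_volume = audio.abs().max(dim=-1)[0]
norm_value[max_volume > threshold] = max_volume[max_volume > threshold] / threshold
print(audio / norm_value.unsqueeze(-1))        # row 0 halved (peak becomes 0.9), row 1 unchanged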
361
+ def extract_wav2vec_embeds(self, input_audios,output_len):
362
+ wav2vec_stride = 2
363
+
364
+ wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
365
+ # print(wav2vec_embeds)
366
+ # print("audio.shape:",input_audios.shape)
367
+ wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer]
368
+ # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape)
369
+ wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
370
+ return wav2vec_embeds_last
371
+
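# Illustrative sketch (toy tensors, not part of the committed file) of the length
# alignment in extract_wav2vec_embeds: HuBERT features are linearly interpolated along
# time so their frame count matches the BEST-RQ conditioning features.
import torch

hubert_feats = torch.randn(1, 499, 1024)       # (batch, time, dim) at the HuBERT frame rate
target_len = 896                               # BEST-RQ frame count (roughly 35.84 s of audio)
aligned = torch.nn.functional.interpolate(
    hubert_feats.permute(0, 2, 1), size=target_len, mode='linear', align_corners=False
).permute(0, 2, 1)
print(aligned.shape)                           # torch.Size([1, 896, 1024])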
372
+ def extract_mert_embeds(self, input_audios):
373
+ prompt_stride = 3
374
+ inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
375
+ input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
376
+ prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
377
+ mert_emb= prompt_embeds[-1]
378
+ mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1)
379
+
380
+ return mert_emb
381
+
382
+ def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer):
383
+ self.bestrq.eval()
384
+ # print("audio shape:",input_audio_0.shape)
385
+ input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
386
+ # print("input_wav_mean.shape:",input_wav_mean.shape)
387
+ # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device)
388
+ input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
389
+ layer_results = input_wav_mean['layer_results']
390
+ # print("layer_results.shape:",layer_results[layer].shape)
391
+ bestrq_emb = layer_results[layer]
392
+ bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
393
+ #[b,t,1024] t=t/960
394
+ #35.84s->batch,896,1024
395
+ return bestrq_emb
396
+
397
+
398
+ def extract_spk_embeds(self, input_audios):
399
+ spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
400
+ spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
401
+ return spk_embeds
402
+
403
+ def extract_lyric_feats(self, lyric):
404
+ with torch.no_grad():
405
+ try:
406
+ text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
407
+ except:
408
+ text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
409
+ text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
410
+ text_mask = text_mask.to(self.device)
411
+ text_encoder_hidden_states, text_mask, text_prompt_embeds = \
412
+ pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
413
+ text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
414
+ return text_encoder_hidden_states, text_mask
415
+
416
+ def extract_energy_bar(self, input_audios):
417
+ if(input_audios.shape[-1] % self.num_samples_perseg > 0):
418
+ energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
419
+ else:
420
+ energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
421
+ energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
422
+ energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
423
+ energy_embedding = self.energy_embedding(energy_bar)
424
+ energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
425
+ return energy_embedding
426
+
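# Illustrative sketch (toy audio, not part of the committed file) of extract_energy_bar:
# the waveform is cut into 20 ms segments, each segment is reduced to RMS energy in dB,
# and the result is clamped into integer bins 0..16 for an embedding lookup.
import torch

sample_rate = 48000
num_samples_perseg = sample_rate * 20 // 1000              # 960 samples per 20 ms
audio = torch.randn(1, num_samples_perseg * 4) * 0.1
segments = audio.reshape(audio.shape[0], -1, num_samples_perseg)
energy_db = (segments.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20
energy_bins = (energy_db / 2.0 + 16).clamp(0, 16).int()
print(energy_bins)                                         # one integer bin per 20 ms segment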
427
+ def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \
428
+ additional_feats = ['spk', 'lyric'], \
429
+ train_rvq=True, train_ssl=False,layer=5):
430
+ if not hasattr(self,"device"):
431
+ self.device = input_audios.device
432
+ if not hasattr(self,"dtype"):
433
+ self.dtype = input_audios.dtype
434
+ device = self.device
435
+ input_audio_0 = input_audios[:,0,:]
436
+ input_audio_1 = input_audios[:,1,:]
437
+ input_audio_0 = self.preprocess_audio(input_audio_0)
438
+ input_audio_1 = self.preprocess_audio(input_audio_1)
439
+ input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0
440
+ # energy_embedding = self.extract_energy_bar(input_audios)
441
+ # print("energy_embedding.shape:",energy_embedding.shape)
442
+ # with autocast(enabled=False):
443
+ if(train_ssl):
444
+ self.wav2vec.train()
445
+ wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
446
+ self.clap_embd_extractor.train()
447
+ prompt_embeds = self.extract_mert_embeds(input_audios)
448
+ if('spk' in additional_feats):
449
+ self.xvecmodel.train()
450
+ spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
451
+ else:
452
+ with torch.no_grad():
453
+ with autocast(enabled=False):
454
+ bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
455
+ # mert_emb = self.extract_mert_embeds(input_audios_mert)
456
+
457
+ wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2])
458
+
459
+ bestrq_emb = bestrq_emb.detach()
460
+ if('lyric' in additional_feats):
461
+ text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
462
+ else:
463
+ text_encoder_hidden_states, text_mask = None, None
464
+
465
+ # prompt_embeds_13 = torch.cat([mert_emb_13, energy_embedding], 1)
466
+ # print("prompt_embes.shape:",prompt_embeds.shape)
467
+ #prompt_embes.shape: torch.Size([3, 1088, 896])
468
+ # print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
469
+ #wav2vec_embeds.shape:torch.Size([3, 1024, 896])
470
+ if(train_rvq):
471
+ quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
472
+ else:
473
+ bestrq_emb = bestrq_emb.float()
474
+ self.rvq_bestrq_emb.eval()
475
+ # with autocast(enabled=False):
476
+ quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
477
+ commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
478
+ codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
479
+ quantized_bestrq_emb = quantized_bestrq_emb.detach()
480
+
481
+ commitment_loss = commitment_loss_bestrq_emb
482
+ codebook_loss = codebook_loss_bestrq_emb
483
+
484
+
485
+ alpha=1
486
+ quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
487
+
488
+ # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
489
+ # print("latent_masks.shape:",latent_masks.shape)
490
+ # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
491
+
492
+
493
+
494
+ scenario = np.random.choice(['start_seg', 'other_seg'])
495
+ if(scenario == 'other_seg'):
496
+ for binx in range(input_audios.shape[0]):
497
+ # latent_masks[binx,0:64] = 1
498
+ latent_masks[binx,0:random.randint(64,128)] = 1
499
+ quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
500
+ # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
501
+ # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape)
502
+ # print("latent_masks.shape:",latent_masks.shape)
503
+ quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
504
+ + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
505
+
506
+
507
+
508
+
509
+ if self.uncondition:
510
+ mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
511
+ if len(mask_indices) > 0:
512
+ quantized_bestrq_emb[mask_indices] = 0
513
+ # print("latents.shape:",latents.shape)
514
+ latents = latents.permute(0,2,1).contiguous()
515
+ latents = self.normfeat.project_sample(latents)
516
+ latents = latents.permute(0,2,1).contiguous()
517
+ incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
518
+ attention_mask=(latent_masks > 0.5)
519
+ B, L = attention_mask.size()
520
+ attention_mask = attention_mask.view(B, 1, L)
521
+ attention_mask = attention_mask * attention_mask.transpose(-1, -2)
522
+ attention_mask = attention_mask.unsqueeze(1)
523
+ # print("incontext_latents.shape:",incontext_latents.shape)
524
+ # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
525
+ latent_mask_input = self.mask_emb(latent_masks)
526
+ #64+48+64+1024
527
+ loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
528
+ return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
529
+
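# Illustrative sketch (toy sizes, not part of the committed file) of the latent-mask
# convention used in forward/inference_codes: 0 marks masked/unconditioned frames,
# 1 marks in-context frames that stay fixed, and 2 marks frames to be generated.
import torch

latent_masks = torch.zeros(1, 8, dtype=torch.int64)
latent_masks[:, 0:6] = 2                       # six valid frames in this segment
latent_masks[:, 0:2] = 1                       # the first two act as the in-context prompt

attention_mask = (latent_masks > 0.5)
B, L = attention_mask.size()
attention_mask = attention_mask.view(B, 1, L)
attention_mask = (attention_mask * attention_mask.transpose(-1, -2)).unsqueeze(1)
incontext_sel = ((latent_masks > 0.5) * (latent_masks < 1.5)).float()
print(attention_mask.shape)                    # torch.Size([1, 1, 8, 8]) pairwise valid-frame mask
print(incontext_sel)                           # 1.0 only on the two in-context frames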
530
+ def init_device_dtype(self, device, dtype):
531
+ self.device = device
532
+ self.dtype = dtype
533
+
534
+ @torch.no_grad()
535
+ def fetch_codes(self, input_audios, additional_feats,layer):
536
+ input_audio_0 = input_audios[[0],:]
537
+ input_audio_1 = input_audios[[1],:]
538
+ input_audio_0 = self.preprocess_audio(input_audio_0)
539
+ input_audio_1 = self.preprocess_audio(input_audio_1)
540
+
541
+ self.bestrq.eval()
542
+
543
+ # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
544
+ # bestrq_middle = bestrq_middle.detach()
545
+ # bestrq_last = bestrq_last.detach()
546
+ bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
547
+ bestrq_emb = bestrq_emb.detach()
548
+
549
+ # self.rvq_bestrq_middle.eval()
550
+ # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
551
+ # self.rvq_bestrq_last.eval()
552
+ # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
553
+
554
+ self.rvq_bestrq_emb.eval()
555
+ quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
556
+
557
+
558
+ if('spk' in additional_feats):
559
+ self.xvecmodel.eval()
560
+ spk_embeds = self.extract_spk_embeds(input_audios)
561
+ else:
562
+ spk_embeds = None
563
+
564
+ # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
565
+ # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
566
+ # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
567
+ return [codes_bestrq_emb], [bestrq_emb], spk_embeds
568
+ # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
569
+
570
+
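# Hedged sketch (randomly initialized quantizer, no checkpoint; not part of the committed
# file) of the RVQ round trip behind fetch_codes / inference_codes: forward() produces
# discrete codes, and from_codes() maps codes back to the continuous conditioning embedding.
import torch
from libs.rvq.descript_quantize3 import ResidualVectorQuantize

rvq = ResidualVectorQuantize(input_dim=1024, n_codebooks=1, codebook_size=16_384,
                             codebook_dim=32, quantizer_dropout=0.0, stale_tolerance=200)
rvq.eval()
with torch.no_grad():
    ssl_feats = torch.randn(1, 1024, 100)      # (batch, dim, frames) BEST-RQ features
    quantized, codes, *_ = rvq(ssl_feats)
    requantized, _, _ = rvq.from_codes(codes)
print(codes.shape, requantized.shape)          # expected: (1, 1, 100) codes, (1, 1024, 100) embedding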
571
+ @torch.no_grad()
572
+ def fetch_codes_batch(self, input_audios, additional_feats,layer):
573
+ input_audio_0 = input_audios[:,0,:]
574
+ input_audio_1 = input_audios[:,1,:]
575
+ input_audio_0 = self.preprocess_audio(input_audio_0)
576
+ input_audio_1 = self.preprocess_audio(input_audio_1)
577
+
578
+ self.bestrq.eval()
579
+
580
+ # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
581
+ # bestrq_middle = bestrq_middle.detach()
582
+ # bestrq_last = bestrq_last.detach()
583
+ bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
584
+ bestrq_emb = bestrq_emb.detach()
585
+
586
+ # self.rvq_bestrq_middle.eval()
587
+ # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
588
+ # self.rvq_bestrq_last.eval()
589
+ # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
590
+
591
+ self.rvq_bestrq_emb.eval()
592
+ quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
593
+
594
+
595
+ if('spk' in additional_feats):
596
+ self.xvecmodel.eval()
597
+ spk_embeds = self.extract_spk_embeds(input_audios)
598
+ else:
599
+ spk_embeds = None
600
+
601
+ # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
602
+ # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
603
+ # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
604
+ return [codes_bestrq_emb], [bestrq_emb], spk_embeds
605
+
606
+ @torch.no_grad()
607
+ def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127,
608
+ guidance_scale=2, num_steps=20,
609
+ disable_progress=True, scenario='start_seg'):
610
+ classifier_free_guidance = guidance_scale > 1.0
611
+ device = self.device
612
+ dtype = self.dtype
613
+ # codes_bestrq_middle, codes_bestrq_last = codes
614
+ codes_bestrq_emb = codes[0]
615
+
616
+
617
+ batch_size = codes_bestrq_emb.shape[0]
618
+
619
+
620
+ quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
621
+ # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
622
+ quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
623
+ print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
624
+ # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
625
+
626
+
627
+
628
+
629
+ if('spk' in additional_feats):
630
+ spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
631
+
632
+ num_frames = quantized_bestrq_emb.shape[1]
633
+
634
+ num_channels_latents = self.num_channels
635
+ shape = (batch_size, num_frames, 64)
636
+ latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
637
+
638
+
639
+
640
+ latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
641
+ latent_masks[:,0:latent_length] = 2
642
+ if(scenario=='other_seg'):
643
+ latent_masks[:,0:incontext_length] = 1
644
+
645
+
646
+
647
+ quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
648
+ + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
649
+ true_latents = true_latents.permute(0,2,1).contiguous()
650
+ true_latents = self.normfeat.project_sample(true_latents)
651
+ true_latents = true_latents.permute(0,2,1).contiguous()
652
+ incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
653
+ incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
654
+
655
+
656
+ attention_mask=(latent_masks > 0.5)
657
+ B, L = attention_mask.size()
658
+ attention_mask = attention_mask.view(B, 1, L)
659
+ attention_mask = attention_mask * attention_mask.transpose(-1, -2)
660
+ attention_mask = attention_mask.unsqueeze(1)
661
+ latent_mask_input = self.mask_emb(latent_masks)
662
+
663
+ if('spk' in additional_feats):
664
+ # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
665
+ additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1)
666
+ else:
667
+ # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
668
+ additional_model_input = torch.cat([quantized_bestrq_emb],1)
669
+
670
+ temperature = 1.0
671
+ t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
672
+ latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
673
+
674
+ latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
675
+ latents = latents.permute(0,2,1).contiguous()
676
+ latents = self.normfeat.return_sample(latents)
677
+ # latents = latents.permute(0,2,1).contiguous()
678
+ return latents
679
+
680
+ @torch.no_grad()
681
+ def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
682
+ disable_progress=True,layer=5,scenario='start_seg'):
683
+ codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
684
+
685
+ latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
686
+ guidance_scale=guidance_scale, num_steps=num_steps, \
687
+ disable_progress=disable_progress,scenario=scenario)
688
+ return latents
689
+
690
+ @torch.no_grad()
691
+ def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
692
+ disable_progress=True,layer=5,scenario='start_seg'):
693
+ codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
694
+ import time
695
+ start = time.time()
696
+ latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
697
+ guidance_scale=guidance_scale, num_steps=num_steps, \
698
+ disable_progress=disable_progress,scenario=scenario)
699
+ return latents,time.time()-start
700
+
701
+ def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
702
+ divisor = 4
703
+ shape = (batch_size, num_channels_latents, num_frames, 32)
704
+ if(num_frames%divisor>0):
705
+ num_frames = round(num_frames/float(divisor))*divisor
706
+ shape = (batch_size, num_channels_latents, num_frames, 32)
707
+ latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
708
+ return latents
709
+
710
+
codeclm/tokenizer/Flow1dVAE/model_2rvq.py CHANGED
@@ -1,774 +1,774 @@
1
- import yaml
2
- import random
3
- import inspect
4
- import numpy as np
5
- from tqdm import tqdm
6
- import typing as tp
7
- from abc import ABC
8
-
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
- import torchaudio
13
-
14
- from einops import repeat
15
- from tools.torch_tools import wav_to_fbank
16
-
17
- import diffusers
18
- from diffusers.utils.torch_utils import randn_tensor
19
- from diffusers import DDPMScheduler
20
- from models.transformer_2d_flow import Transformer2DModel
21
- from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel
22
- # from tools.get_mulan import get_mulan
23
- from third_party.wespeaker.extract_embd import XVECModel
24
- # from libs.rvq2 import RVQEmbedding
25
- from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize
26
-
27
- from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
28
- from models_gpt.models.gpt2_config import GPT2Config
29
-
30
- from torch.cuda.amp import autocast
31
-
32
-
33
- from our_MERT_BESTRQ.test import load_model
34
-
35
- class HubertModelWithFinalProj(HubertModel):
36
- def __init__(self, config):
37
- super().__init__(config)
38
-
39
- # The final projection layer is only used for backward compatibility.
40
- # Following https://github.com/auspicious3000/contentvec/issues/6
41
- # Remove this layer is necessary to achieve the desired outcome.
42
- print("hidden_size:",config.hidden_size)
43
- print("classifier_proj_size:",config.classifier_proj_size)
44
- self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
45
-
46
-
47
- class SampleProcessor(torch.nn.Module):
48
- def project_sample(self, x: torch.Tensor):
49
- """Project the original sample to the 'space' where the diffusion will happen."""
50
- """Project back from diffusion space to the actual sample space."""
51
- return z
52
-
53
- class Feature1DProcessor(SampleProcessor):
54
- def __init__(self, dim: int = 100, power_std = 1., \
55
- num_samples: int = 100_000, cal_num_frames: int = 600):
56
- super().__init__()
57
-
58
- self.num_samples = num_samples
59
- self.dim = dim
60
- self.power_std = power_std
61
- self.cal_num_frames = cal_num_frames
62
- self.register_buffer('counts', torch.zeros(1))
63
- self.register_buffer('sum_x', torch.zeros(dim))
64
- self.register_buffer('sum_x2', torch.zeros(dim))
65
- self.register_buffer('sum_target_x2', torch.zeros(dim))
66
- self.counts: torch.Tensor
67
- self.sum_x: torch.Tensor
68
- self.sum_x2: torch.Tensor
69
-
70
- @property
71
- def mean(self):
72
- mean = self.sum_x / self.counts
73
- if(self.counts < 10):
74
- mean = torch.zeros_like(mean)
75
- return mean
76
-
77
- @property
78
- def std(self):
79
- std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
80
- if(self.counts < 10):
81
- std = torch.ones_like(std)
82
- return std
83
-
84
- @property
85
- def target_std(self):
86
- return 1
87
-
88
- def project_sample(self, x: torch.Tensor):
89
- assert x.dim() == 3
90
- if self.counts.item() < self.num_samples:
91
- self.counts += len(x)
92
- self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
93
- self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
94
- rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
95
- x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
96
- return x
97
-
98
- def return_sample(self, x: torch.Tensor):
99
- assert x.dim() == 3
100
- rescale = (self.std / self.target_std) ** self.power_std
101
- # print(rescale, self.mean)
102
- x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
103
- return x
104
-
105
- def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
106
- if(prior_text_encoder_hidden_states.shape[1]<len_size):
107
- prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
108
- torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
109
- prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
110
- dtype=prior_text_encoder_hidden_states.dtype)],1)
111
- prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
112
- else:
113
- prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
114
- prior_text_mask = prior_text_mask[:,0:len_size]
115
- prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
116
- return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
117
-
118
- class BASECFM(torch.nn.Module, ABC):
119
- def __init__(
120
- self,
121
- estimator,
122
- mlp,
123
- ssl_layer
124
- ):
125
- super().__init__()
126
- self.sigma_min = 1e-4
127
-
128
- self.estimator = estimator
129
- self.mlp = mlp
130
- self.ssl_layer = ssl_layer
131
-
132
- @torch.inference_mode()
133
- def forward(self, mu, n_timesteps, temperature=1.0):
134
- """Forward diffusion
135
-
136
- Args:
137
- mu (torch.Tensor): output of encoder
138
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
139
- n_timesteps (int): number of diffusion steps
140
- temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
141
-
142
- Returns:
143
- sample: generated mel-spectrogram
144
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
145
- """
146
- z = torch.randn_like(mu) * temperature
147
- t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
148
- return self.solve_euler(z, t_span=t_span)
149
-
150
- def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
151
- """
152
- Fixed euler solver for ODEs.
153
- Args:
154
- x (torch.Tensor): random noise
155
- t_span (torch.Tensor): n_timesteps interpolated
156
- shape: (n_timesteps + 1,)
157
- mu (torch.Tensor): output of encoder
158
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
159
- """
160
- t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
161
- noise = x.clone()
162
-
163
- # I am storing this because I can later plot it by putting a debugger here and saving it to a file
164
- # Or in future might add like a return_all_steps flag
165
- sol = []
166
-
167
- for step in tqdm(range(1, len(t_span))):
168
- # print("incontext_x.shape:",incontext_x.shape)
169
- # print("noise.shape:",noise.shape)
170
- # print("t.shape:",t.shape)
171
- x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
172
- if(guidance_scale > 1.0):
173
-
174
- model_input = torch.cat([ \
175
- torch.cat([latent_mask_input, latent_mask_input], 0), \
176
- torch.cat([incontext_x, incontext_x], 0), \
177
- torch.cat([torch.zeros_like(mu), mu], 0), \
178
- torch.cat([x, x], 0), \
179
- ], 2)
180
- timestep=t.unsqueeze(-1).repeat(2)
181
-
182
- dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
183
- dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0)
184
- dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond)
185
- else:
186
- model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
187
- timestep=t.unsqueeze(-1)
188
- dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
189
-
190
- dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
191
- # print("dphi_dt.shape:",dphi_dt.shape)
192
- # print("x.shape:",x.shape)
193
-
194
- x = x + dt * dphi_dt
195
- t = t + dt
196
- sol.append(x)
197
- if step < len(t_span) - 1:
198
- dt = t_span[step + 1] - t
199
-
200
- return sol[-1]
201
-
202
- def projection_loss(self,hidden_proj, bestrq_emb):
203
- bsz = hidden_proj.shape[0]
204
-
205
- hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
206
- bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
207
-
208
- proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
209
- proj_loss = 1+proj_loss.mean()
210
-
211
- return proj_loss
212
-
213
- def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
214
- """Computes diffusion loss
215
-
216
- Args:
217
- x1 (torch.Tensor): Target
218
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
219
- mu (torch.Tensor): output of encoder
220
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
221
-
222
- Returns:
223
- loss: conditional flow matching loss
224
- y: conditional flow
225
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
226
- """
227
- b = mu[0].shape[0]
228
- len_x = x1.shape[2]
229
- # random timestep
230
- if(validation_mode):
231
- t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
232
- else:
233
- t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
234
- # sample noise p(x_0)
235
- z = torch.randn_like(x1)
236
-
237
- y = (1 - (1 - self.sigma_min) * t) * z + t * x1
238
- u = x1 - (1 - self.sigma_min) * z
239
- # print("y.shape:",y.shape)
240
- #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state
241
- model_input = torch.cat([*mu,y], 2)
242
- t=t.squeeze(-1).squeeze(-1)
243
- # print("model_input.shape:",model_input.shape)
244
- # print("attention_mask.shape:",attention_mask.shape)
245
- out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
246
- hidden_layer = out.hidden_states[self.ssl_layer]
247
- hidden_proj = self.mlp(hidden_layer)
248
- # print("hidden_proj.shape:",hidden_proj.shape)
249
- # print("mert_emb.shape:",mert_emb.shape)
250
- # exit()
251
-
252
-
253
- out = out.last_hidden_state
254
-
255
- out=out[:,:,-len_x:]
256
- # out=self.proj_out(out)
257
-
258
- weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
259
- # print("out.shape",out.shape)
260
- # print("u.shape",u.shape)
261
- loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
262
- # print("hidden_proj.shape:",hidden_proj.shape)
263
- # print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
264
- loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
265
- loss = loss_re + loss_cos * 0.5
266
- # print("loss_cos:",loss_cos,loss_cos.device)
267
- print("loss:",loss,loss.device)
268
- # exit()
269
- return loss, loss_re, loss_cos
270
-
271
- class PromptCondAudioDiffusion(nn.Module):
272
- def __init__(
273
- self,
274
- num_channels,
275
- unet_model_name=None,
276
- unet_model_config_path=None,
277
- snr_gamma=None,
278
- hubert_layer=None,
279
- ssl_layer=None,
280
- uncondition=True,
281
- out_paint=False,
282
- ):
283
- super().__init__()
284
-
285
- assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
286
-
287
- self.unet_model_name = unet_model_name
288
- self.unet_model_config_path = unet_model_config_path
289
- self.snr_gamma = snr_gamma
290
- self.uncondition = uncondition
291
- self.num_channels = num_channels
292
- self.hubert_layer = hubert_layer
293
- self.ssl_layer = ssl_layer
294
-
295
- # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
296
- self.normfeat = Feature1DProcessor(dim=64)
297
-
298
- self.sample_rate = 48000
299
- self.num_samples_perseg = self.sample_rate * 20 // 1000
300
- self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
301
- self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
302
- # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
303
- # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
304
- self.bestrq = load_model(
305
- model_dir='path/to/our-MERT/mert_fairseq',
306
- checkpoint_dir='checkpoint-120000.pt',
307
- )
308
- self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
309
- self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
310
- for v in self.bestrq.parameters():v.requires_grad = False
311
- self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 2, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
312
- # for v in self.rvq_bestrq_emb.parameters():
313
- # print(v)
314
- freeze_parameters='quantizers.0'
315
- for name, param in self.rvq_bestrq_emb.named_parameters():
316
- if freeze_parameters in name:
317
- param.requires_grad = False
318
- print("Freezing RVQ parameters:", name)
319
- self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
320
- for v in self.hubert.parameters():v.requires_grad = False
321
- self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
322
- # self.xvecmodel = XVECModel()
323
- config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
324
- unet = GPT2Model(config)
325
- mlp = nn.Sequential(
326
- nn.Linear(1200, 1024),
327
- nn.SiLU(),
328
- nn.Linear(1024, 1024),
329
- nn.SiLU(),
330
- nn.Linear(1024, 768)
331
- )
332
- self.set_from = "random"
333
- self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
334
- self.mask_emb = torch.nn.Embedding(3, 48)
335
- print("Transformer initialized from pretrain.")
336
- torch.cuda.empty_cache()
337
- # self.unet.set_attn_processor(AttnProcessor2_0())
338
- # self.unet.set_use_memory_efficient_attention_xformers(True)
339
-
340
- # self.start_embedding = nn.Parameter(torch.randn(1,1024))
341
- # self.end_embedding = nn.Parameter(torch.randn(1,1024))
342
-
343
- def compute_snr(self, timesteps):
344
- """
345
- Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
346
- """
347
- alphas_cumprod = self.noise_scheduler.alphas_cumprod
348
- sqrt_alphas_cumprod = alphas_cumprod**0.5
349
- sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
350
-
351
- # Expand the tensors.
352
- # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
353
- sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
354
- while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
355
- sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
356
- alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
357
-
358
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
359
- while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
360
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
361
- sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
362
-
363
- # Compute SNR.
364
- snr = (alpha / sigma) ** 2
365
- return snr
366
-
367
- def preprocess_audio(self, input_audios, threshold=0.9):
368
- assert len(input_audios.shape) == 2, input_audios.shape
369
- norm_value = torch.ones_like(input_audios[:,0])
370
- max_volume = input_audios.abs().max(dim=-1)[0]
371
- norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
372
- return input_audios/norm_value.unsqueeze(-1)
373
-
374
- def extract_wav2vec_embeds(self, input_audios,output_len):
375
- wav2vec_stride = 2
376
-
377
- wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
378
- # print(wav2vec_embeds)
379
- # print("audio.shape:",input_audios.shape)
380
- wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer]
381
- # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape)
382
- wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
383
- return wav2vec_embeds_last
384
-
385
- def extract_mert_embeds(self, input_audios):
386
- prompt_stride = 3
387
- inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
388
- input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
389
- prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
390
- mert_emb= prompt_embeds[-1]
391
- mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1)
392
-
393
- return mert_emb
394
-
395
- def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer):
396
- self.bestrq.eval()
397
- # print("audio shape:",input_audio_0.shape)
398
- input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
399
- # print("input_wav_mean.shape:",input_wav_mean.shape)
400
- # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device)
401
- input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
402
- layer_results = input_wav_mean['layer_results']
403
- # print("layer_results.shape:",layer_results[layer].shape)
404
- bestrq_emb = layer_results[layer]
405
- bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
406
- #[b,t,1024] t=t/960
407
- #35.84s->batch,896,1024
408
- return bestrq_emb
409
-
410
-
411
- def extract_spk_embeds(self, input_audios):
412
- spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
413
- spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
414
- return spk_embeds
415
-
416
- def extract_lyric_feats(self, lyric):
417
- with torch.no_grad():
418
- try:
419
- text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
420
- except:
421
- text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
422
- text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
423
- text_mask = text_mask.to(self.device)
424
- text_encoder_hidden_states, text_mask, text_prompt_embeds = \
425
- pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
426
- text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
427
- return text_encoder_hidden_states, text_mask
428
-
429
- def extract_energy_bar(self, input_audios):
430
- if(input_audios.shape[-1] % self.num_samples_perseg > 0):
431
- energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
432
- else:
433
- energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
434
- energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
435
- energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
436
- energy_embedding = self.energy_embedding(energy_bar)
437
- energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
438
- return energy_embedding
439
-
440
- def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \
441
- additional_feats = ['spk', 'lyric'], \
442
- train_rvq=True, train_ssl=False,layer=5):
443
- if not hasattr(self,"device"):
444
- self.device = input_audios.device
445
- if not hasattr(self,"dtype"):
446
- self.dtype = input_audios.dtype
447
- device = self.device
448
- input_audio_0 = input_audios[:,0,:]
449
- input_audio_1 = input_audios[:,1,:]
450
- input_audio_0 = self.preprocess_audio(input_audio_0)
451
- input_audio_1 = self.preprocess_audio(input_audio_1)
452
- input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0
453
- # energy_embedding = self.extract_energy_bar(input_audios)
454
- # print("energy_embedding.shape:",energy_embedding.shape)
455
- # with autocast(enabled=False):
456
- if(train_ssl):
457
- self.wav2vec.train()
458
- wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
459
- self.clap_embd_extractor.train()
460
- prompt_embeds = self.extract_mert_embeds(input_audios)
461
- if('spk' in additional_feats):
462
- self.xvecmodel.train()
463
- spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
464
- else:
465
- with torch.no_grad():
466
- with autocast(enabled=False):
467
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
468
- # mert_emb = self.extract_mert_embeds(input_audios_mert)
469
-
470
- wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2])
471
-
472
- bestrq_emb = bestrq_emb.detach()
473
- if('lyric' in additional_feats):
474
- text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
475
- else:
476
- text_encoder_hidden_states, text_mask = None, None
477
-
478
-
479
- if(train_rvq):
480
- random_num=random.random()
481
- if(random_num<0.6):
482
- rvq_layer = 1
483
- elif(random_num<0.8):
484
- rvq_layer = 2
485
- else:
486
- rvq_layer = 4
487
- quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t
488
- else:
489
- bestrq_emb = bestrq_emb.float()
490
- self.rvq_bestrq_emb.eval()
491
- # with autocast(enabled=False):
492
- quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
493
- commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
494
- codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
495
- quantized_bestrq_emb = quantized_bestrq_emb.detach()
496
-
497
- commitment_loss = commitment_loss_bestrq_emb
498
- codebook_loss = codebook_loss_bestrq_emb
499
-
500
-
501
- alpha=1
502
- quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
503
-
504
- # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
505
- # print("latent_masks.shape:",latent_masks.shape)
506
- # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
507
-
508
-
509
-
510
- scenario = np.random.choice(['start_seg', 'other_seg'])
511
- if(scenario == 'other_seg'):
512
- for binx in range(input_audios.shape[0]):
513
- # latent_masks[binx,0:64] = 1
514
- latent_masks[binx,0:random.randint(64,128)] = 1
515
- quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
516
- # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
517
- # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape)
518
- # print("latent_masks.shape:",latent_masks.shape)
519
- quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
520
- + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
521
-
522
-
523
-
524
-
525
- if self.uncondition:
526
- mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
527
- if len(mask_indices) > 0:
528
- quantized_bestrq_emb[mask_indices] = 0
529
- # print("latents.shape:",latents.shape)
530
- latents = latents.permute(0,2,1).contiguous()
531
- latents = self.normfeat.project_sample(latents)
532
- latents = latents.permute(0,2,1).contiguous()
533
- incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
534
- attention_mask=(latent_masks > 0.5)
535
- B, L = attention_mask.size()
536
- attention_mask = attention_mask.view(B, 1, L)
537
- attention_mask = attention_mask * attention_mask.transpose(-1, -2)
538
- attention_mask = attention_mask.unsqueeze(1)
539
- # print("incontext_latents.shape:",incontext_latents.shape)
540
- # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
541
- latent_mask_input = self.mask_emb(latent_masks)
542
- #64+48+64+1024
543
- loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
544
- return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
545
-
546
- def init_device_dtype(self, device, dtype):
547
- self.device = device
548
- self.dtype = dtype
549
-
550
- @torch.no_grad()
551
- def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1):
552
- input_audio_0 = input_audios[[0],:]
553
- input_audio_1 = input_audios[[1],:]
554
- input_audio_0 = self.preprocess_audio(input_audio_0)
555
- input_audio_1 = self.preprocess_audio(input_audio_1)
556
-
557
- self.bestrq.eval()
558
-
559
- # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
560
- # bestrq_middle = bestrq_middle.detach()
561
- # bestrq_last = bestrq_last.detach()
562
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
563
- bestrq_emb = bestrq_emb.detach()
564
-
565
- # self.rvq_bestrq_middle.eval()
566
- # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
567
- # self.rvq_bestrq_last.eval()
568
- # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
569
-
570
- self.rvq_bestrq_emb.eval()
571
- quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
572
- codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
573
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
574
- # exit()
575
-
576
-
577
- if('spk' in additional_feats):
578
- self.xvecmodel.eval()
579
- spk_embeds = self.extract_spk_embeds(input_audios)
580
- else:
581
- spk_embeds = None
582
-
583
- # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
584
- # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
585
- # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
586
- return [codes_bestrq_emb], [bestrq_emb], spk_embeds
587
- # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
588
-
589
- @torch.no_grad()
590
- def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1):
591
- input_audio_0 = input_audios[:,0,:]
592
- input_audio_1 = input_audios[:,1,:]
593
- input_audio_0 = self.preprocess_audio(input_audio_0)
594
- input_audio_1 = self.preprocess_audio(input_audio_1)
595
-
596
- self.bestrq.eval()
597
-
598
- # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
599
- # bestrq_middle = bestrq_middle.detach()
600
- # bestrq_last = bestrq_last.detach()
601
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
602
- bestrq_emb = bestrq_emb.detach()
603
-
604
- # self.rvq_bestrq_middle.eval()
605
- # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
606
- # self.rvq_bestrq_last.eval()
607
- # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
608
-
609
- self.rvq_bestrq_emb.eval()
610
- quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
611
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
612
- codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
613
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
614
- # exit()
615
-
616
-
617
- if('spk' in additional_feats):
618
- self.xvecmodel.eval()
619
- spk_embeds = self.extract_spk_embeds(input_audios)
620
- else:
621
- spk_embeds = None
622
-
623
- # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
624
- # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
625
- # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
626
- return [codes_bestrq_emb], [bestrq_emb], spk_embeds
627
- # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
628
-
629
- @torch.no_grad()
630
- def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250):
631
- input_audio_0 = input_audios[:,0,:]
632
- input_audio_1 = input_audios[:,1,:]
633
- input_audio_0 = self.preprocess_audio(input_audio_0)
634
- input_audio_1 = self.preprocess_audio(input_audio_1)
635
-
636
- self.bestrq.eval()
637
-
638
- # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
639
- # bestrq_middle = bestrq_middle.detach()
640
- # bestrq_last = bestrq_last.detach()
641
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
642
- bestrq_emb = bestrq_emb.detach()
643
-
644
- # self.rvq_bestrq_middle.eval()
645
- # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
646
- # self.rvq_bestrq_last.eval()
647
- # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
648
-
649
- self.rvq_bestrq_emb.eval()
650
- bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds)
651
- quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
652
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
653
- codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
654
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
655
- # exit()
656
-
657
-
658
- if('spk' in additional_feats):
659
- self.xvecmodel.eval()
660
- spk_embeds = self.extract_spk_embeds(input_audios)
661
- else:
662
- spk_embeds = None
663
-
664
- # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
665
- # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
666
- # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
667
- return [codes_bestrq_emb], [bestrq_emb], spk_embeds
668
- # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
669
-
670
- @torch.no_grad()
671
- def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127,
672
- guidance_scale=2, num_steps=20,
673
- disable_progress=True, scenario='start_seg'):
674
- classifier_free_guidance = guidance_scale > 1.0
675
- device = self.device
676
- dtype = self.dtype
677
- # codes_bestrq_middle, codes_bestrq_last = codes
678
- codes_bestrq_emb = codes[0]
679
-
680
-
681
- batch_size = codes_bestrq_emb.shape[0]
682
-
683
-
684
- quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
685
- # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
686
- quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
687
- print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
688
- # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
689
-
690
-
691
-
692
-
693
- if('spk' in additional_feats):
694
- spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
695
-
696
- num_frames = quantized_bestrq_emb.shape[1]
697
-
698
- num_channels_latents = self.num_channels
699
- shape = (batch_size, num_frames, 64)
700
- latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
701
-
702
-
703
-
704
- latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
705
- latent_masks[:,0:latent_length] = 2
706
- if(scenario=='other_seg'):
707
- latent_masks[:,0:incontext_length] = 1
708
-
709
-
710
-
711
- quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
712
- + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
713
- true_latents = true_latents.permute(0,2,1).contiguous()
714
- true_latents = self.normfeat.project_sample(true_latents)
715
- true_latents = true_latents.permute(0,2,1).contiguous()
716
- incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
717
- incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
718
-
719
-
720
- attention_mask=(latent_masks > 0.5)
721
- B, L = attention_mask.size()
722
- attention_mask = attention_mask.view(B, 1, L)
723
- attention_mask = attention_mask * attention_mask.transpose(-1, -2)
724
- attention_mask = attention_mask.unsqueeze(1)
725
- latent_mask_input = self.mask_emb(latent_masks)
726
-
727
- if('spk' in additional_feats):
728
- # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
729
- additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1)
730
- else:
731
- # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
732
- additional_model_input = torch.cat([quantized_bestrq_emb],1)
733
-
734
- temperature = 1.0
735
- t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
736
- latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
737
-
738
- latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
739
- latents = latents.permute(0,2,1).contiguous()
740
- latents = self.normfeat.return_sample(latents)
741
- # latents = latents.permute(0,2,1).contiguous()
742
- return latents
743
-
744
- @torch.no_grad()
745
- def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
746
- disable_progress=True,layer=5,scenario='start_seg',rvq_num=1):
747
- codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num)
748
-
749
- latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
750
- guidance_scale=guidance_scale, num_steps=num_steps, \
751
- disable_progress=disable_progress,scenario=scenario)
752
- return latents
753
-
754
- @torch.no_grad()
755
- def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
756
- disable_progress=True,layer=5,scenario='start_seg'):
757
- codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
758
- import time
759
- start = time.time()
760
- latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
761
- guidance_scale=guidance_scale, num_steps=num_steps, \
762
- disable_progress=disable_progress,scenario=scenario)
763
- return latents,time.time()-start
764
-
765
- def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
766
- divisor = 4
767
- shape = (batch_size, num_channels_latents, num_frames, 32)
768
- if(num_frames%divisor>0):
769
- num_frames = round(num_frames/float(divisor))*divisor
770
- shape = (batch_size, num_channels_latents, num_frames, 32)
771
- latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
772
- return latents
773
-
774
-
 
1
+ import yaml
2
+ import random
3
+ import inspect
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ import typing as tp
7
+ from abc import ABC
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torchaudio
13
+
14
+ from einops import repeat
15
+ from tools.torch_tools import wav_to_fbank
16
+
17
+ import diffusers
18
+ from diffusers.utils.torch_utils import randn_tensor
19
+ from diffusers import DDPMScheduler
20
+ from models.transformer_2d_flow import Transformer2DModel
21
+ from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel
22
+ # from tools.get_mulan import get_mulan
23
+ from third_party.wespeaker.extract_embd import XVECModel
24
+ # from libs.rvq2 import RVQEmbedding
25
+ from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize
26
+
27
+ from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
28
+ from models_gpt.models.gpt2_config import GPT2Config
29
+
30
+ from torch.cuda.amp import autocast
31
+
32
+
33
+ from our_MERT_BESTRQ.test import load_model
34
+
35
+ class HubertModelWithFinalProj(HubertModel):
36
+ def __init__(self, config):
37
+ super().__init__(config)
38
+
39
+ # The final projection layer is only used for backward compatibility.
40
+ # Following https://github.com/auspicious3000/contentvec/issues/6
41
+ # Removing this layer is necessary to achieve the desired outcome.
42
+ print("hidden_size:",config.hidden_size)
43
+ print("classifier_proj_size:",config.classifier_proj_size)
44
+ self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
45
+
46
+
47
+ class SampleProcessor(torch.nn.Module):
48
+ def project_sample(self, x: torch.Tensor):
+ """Project the original sample to the 'space' where the diffusion will happen."""
+ return x
+
+ def return_sample(self, z: torch.Tensor):
+ """Project back from diffusion space to the actual sample space."""
+ return z
52
+
53
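+ # Feature1DProcessor accumulates running per-channel statistics (mean and std
+ # over the first `cal_num_frames` frames, until `num_samples` samples have been
+ # seen) and uses them to normalize latents in `project_sample` and to restore
+ # the original scale in `return_sample`.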
+ class Feature1DProcessor(SampleProcessor):
54
+ def __init__(self, dim: int = 100, power_std = 1., \
55
+ num_samples: int = 100_000, cal_num_frames: int = 600):
56
+ super().__init__()
57
+
58
+ self.num_samples = num_samples
59
+ self.dim = dim
60
+ self.power_std = power_std
61
+ self.cal_num_frames = cal_num_frames
62
+ self.register_buffer('counts', torch.zeros(1))
63
+ self.register_buffer('sum_x', torch.zeros(dim))
64
+ self.register_buffer('sum_x2', torch.zeros(dim))
65
+ self.register_buffer('sum_target_x2', torch.zeros(dim))
66
+ self.counts: torch.Tensor
67
+ self.sum_x: torch.Tensor
68
+ self.sum_x2: torch.Tensor
69
+
70
+ @property
71
+ def mean(self):
72
+ mean = self.sum_x / self.counts
73
+ if(self.counts < 10):
74
+ mean = torch.zeros_like(mean)
75
+ return mean
76
+
77
+ @property
78
+ def std(self):
79
+ std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
80
+ if(self.counts < 10):
81
+ std = torch.ones_like(std)
82
+ return std
83
+
84
+ @property
85
+ def target_std(self):
86
+ return 1
87
+
88
+ def project_sample(self, x: torch.Tensor):
89
+ assert x.dim() == 3
90
+ if self.counts.item() < self.num_samples:
91
+ self.counts += len(x)
92
+ self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
93
+ self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
94
+ rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
95
+ x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
96
+ return x
97
+
98
+ def return_sample(self, x: torch.Tensor):
99
+ assert x.dim() == 3
100
+ rescale = (self.std / self.target_std) ** self.power_std
101
+ # print(rescale, self.mean)
102
+ x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
103
+ return x
104
+
105
+ def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
106
+ if(prior_text_encoder_hidden_states.shape[1]<len_size):
107
+ prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
108
+ torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
109
+ prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
110
+ dtype=prior_text_encoder_hidden_states.dtype)],1)
111
+ prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
112
+ else:
113
+ prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
114
+ prior_text_mask = prior_text_mask[:,0:len_size]
115
+ prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
116
+ return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
117
+
118
+ class BASECFM(torch.nn.Module, ABC):
119
+ def __init__(
120
+ self,
121
+ estimator,
122
+ mlp,
123
+ ssl_layer
124
+ ):
125
+ super().__init__()
126
+ self.sigma_min = 1e-4
127
+
128
+ self.estimator = estimator
129
+ self.mlp = mlp
130
+ self.ssl_layer = ssl_layer
131
+
132
+ @torch.inference_mode()
133
+ def forward(self, mu, n_timesteps, temperature=1.0):
134
+ """Forward diffusion
135
+
136
+ Args:
137
+ mu (torch.Tensor): output of encoder
138
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
139
+ n_timesteps (int): number of diffusion steps
140
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
141
+
142
+ Returns:
143
+ sample: generated mel-spectrogram
144
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
145
+ """
146
+ z = torch.randn_like(mu) * temperature
147
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
148
+ return self.solve_euler(z, t_span=t_span)
149
+
150
+ def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
151
+ """
152
+ Fixed-step Euler solver for ODEs.
153
+ Args:
154
+ x (torch.Tensor): random noise
155
+ t_span (torch.Tensor): n_timesteps interpolated
156
+ shape: (n_timesteps + 1,)
157
+ mu (torch.Tensor): output of encoder
158
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
159
+ """
160
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
161
+ noise = x.clone()
162
+
163
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
164
+ # Or in future might add like a return_all_steps flag
165
+ sol = []
166
+
167
+ for step in tqdm(range(1, len(t_span))):
168
+ # print("incontext_x.shape:",incontext_x.shape)
169
+ # print("noise.shape:",noise.shape)
170
+ # print("t.shape:",t.shape)
171
+ x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
172
+ if(guidance_scale > 1.0):
173
+
174
+ model_input = torch.cat([ \
175
+ torch.cat([latent_mask_input, latent_mask_input], 0), \
176
+ torch.cat([incontext_x, incontext_x], 0), \
177
+ torch.cat([torch.zeros_like(mu), mu], 0), \
178
+ torch.cat([x, x], 0), \
179
+ ], 2)
180
+ timestep=t.unsqueeze(-1).repeat(2)
181
+
182
+ dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
183
+ dphi_dt_uncond, dphi_dt_cond = dphi_dt.chunk(2, 0)
184
+ dphi_dt = dphi_dt_uncond + guidance_scale * (dphi_dt_cond - dphi_dt_uncond)
185
+ else:
186
+ model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
187
+ timestep=t.unsqueeze(-1)
188
+ dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
189
+
190
+ dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
191
+ # print("dphi_dt.shape:",dphi_dt.shape)
192
+ # print("x.shape:",x.shape)
193
+
194
+ x = x + dt * dphi_dt
195
+ t = t + dt
196
+ sol.append(x)
197
+ if step < len(t_span) - 1:
198
+ dt = t_span[step + 1] - t
199
+
200
+ return sol[-1]
201
+
202
+ def projection_loss(self,hidden_proj, bestrq_emb):
203
+ bsz = hidden_proj.shape[0]
204
+
205
+ hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
206
+ bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
207
+
208
+ proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
209
+ proj_loss = 1+proj_loss.mean()
210
+
211
+ return proj_loss
212
+
213
+ def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
214
+ """Computes diffusion loss
215
+
216
+ Args:
217
+ x1 (torch.Tensor): Target
218
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
219
+ mu (torch.Tensor): output of encoder
220
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
221
+
222
+ Returns:
223
+ loss: conditional flow matching loss
224
+ y: conditional flow
225
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
226
+ """
227
+ b = mu[0].shape[0]
228
+ len_x = x1.shape[2]
229
+ # random timestep
230
+ if(validation_mode):
231
+ t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
232
+ else:
233
+ t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
234
+ # sample noise p(x_0)
235
+ z = torch.randn_like(x1)
236
+
237
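+ # Conditional flow matching target: y linearly interpolates between noise z
+ # (t=0) and the ground-truth latent x1 (t=1); u is the constant velocity along
+ # that path, which the estimator is trained to predict.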
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
238
+ u = x1 - (1 - self.sigma_min) * z
239
+ # print("y.shape:",y.shape)
240
+ #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state
241
+ model_input = torch.cat([*mu,y], 2)
242
+ t=t.squeeze(-1).squeeze(-1)
243
+ # print("model_input.shape:",model_input.shape)
244
+ # print("attention_mask.shape:",attention_mask.shape)
245
+ out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
246
+ hidden_layer = out.hidden_states[self.ssl_layer]
247
+ hidden_proj = self.mlp(hidden_layer)
248
+ # print("hidden_proj.shape:",hidden_proj.shape)
249
+ # print("mert_emb.shape:",mert_emb.shape)
250
+ # exit()
251
+
252
+
253
+ out = out.last_hidden_state
254
+
255
+ out=out[:,:,-len_x:]
256
+ # out=self.proj_out(out)
257
+
258
+ weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
259
+ # print("out.shape",out.shape)
260
+ # print("u.shape",u.shape)
261
+ loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
262
+ # print("hidden_proj.shape:",hidden_proj.shape)
263
+ # print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
264
+ loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
265
+ loss = loss_re + loss_cos * 0.5
266
+ # print("loss_cos:",loss_cos,loss_cos.device)
267
+ print("loss:",loss,loss.device)
268
+ # exit()
269
+ return loss, loss_re, loss_cos
270
+
271
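+ # PromptCondAudioDiffusion ties the pieces together: frozen SSL encoders
+ # (BEST-RQ features, plus HuBERT features used as a projection target), a
+ # residual vector quantizer over the BEST-RQ embeddings, and a GPT-2 style
+ # transformer used as the flow-matching estimator inside BASECFM.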
+ class PromptCondAudioDiffusion(nn.Module):
272
+ def __init__(
273
+ self,
274
+ num_channels,
275
+ unet_model_name=None,
276
+ unet_model_config_path=None,
277
+ snr_gamma=None,
278
+ hubert_layer=None,
279
+ ssl_layer=None,
280
+ uncondition=True,
281
+ out_paint=False,
282
+ ):
283
+ super().__init__()
284
+
285
+ assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
286
+
287
+ self.unet_model_name = unet_model_name
288
+ self.unet_model_config_path = unet_model_config_path
289
+ self.snr_gamma = snr_gamma
290
+ self.uncondition = uncondition
291
+ self.num_channels = num_channels
292
+ self.hubert_layer = hubert_layer
293
+ self.ssl_layer = ssl_layer
294
+
295
+ # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
296
+ self.normfeat = Feature1DProcessor(dim=64)
297
+
298
+ self.sample_rate = 48000
299
+ self.num_samples_perseg = self.sample_rate * 20 // 1000
300
+ self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
301
+ self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
302
+ # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
303
+ # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
304
+ self.bestrq = load_model(
305
+ model_dir='path/to/our-MERT/mert_fairseq',
306
+ checkpoint_dir='checkpoint-120000.pt',
307
+ )
308
+ self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
309
+ self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
310
+ for v in self.bestrq.parameters():v.requires_grad = False
311
+ self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 2, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
312
+ # for v in self.rvq_bestrq_emb.parameters():
313
+ # print(v)
314
+ freeze_parameters='quantizers.0'
315
+ for name, param in self.rvq_bestrq_emb.named_parameters():
316
+ if freeze_parameters in name:
317
+ param.requires_grad = False
318
+ print("Freezing RVQ parameters:", name)
319
+ self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
320
+ for v in self.hubert.parameters():v.requires_grad = False
321
+ self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
322
+ # self.xvecmodel = XVECModel()
323
+ config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
324
+ unet = GPT2Model(config)
325
+ mlp = nn.Sequential(
326
+ nn.Linear(1200, 1024),
327
+ nn.SiLU(),
328
+ nn.Linear(1024, 1024),
329
+ nn.SiLU(),
330
+ nn.Linear(1024, 768)
331
+ )
332
+ self.set_from = "random"
333
+ self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
334
+ self.mask_emb = torch.nn.Embedding(3, 48)
335
+ print("Transformer initialized from pretrain.")
336
+ torch.cuda.empty_cache()
337
+ # self.unet.set_attn_processor(AttnProcessor2_0())
338
+ # self.unet.set_use_memory_efficient_attention_xformers(True)
339
+
340
+ # self.start_embedding = nn.Parameter(torch.randn(1,1024))
341
+ # self.end_embedding = nn.Parameter(torch.randn(1,1024))
342
+
343
+ def compute_snr(self, timesteps):
344
+ """
345
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
346
+ """
347
+ alphas_cumprod = self.noise_scheduler.alphas_cumprod
348
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
349
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
350
+
351
+ # Expand the tensors.
352
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
353
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
354
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
355
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
356
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
357
+
358
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
359
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
360
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
361
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
362
+
363
+ # Compute SNR.
364
+ snr = (alpha / sigma) ** 2
365
+ return snr
366
+
367
+ def preprocess_audio(self, input_audios, threshold=0.9):
368
+ assert len(input_audios.shape) == 2, input_audios.shape
369
+ norm_value = torch.ones_like(input_audios[:,0])
370
+ max_volume = input_audios.abs().max(dim=-1)[0]
371
+ norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
372
+ return input_audios/norm_value.unsqueeze(-1)
373
+
374
+ def extract_wav2vec_embeds(self, input_audios,output_len):
375
+ wav2vec_stride = 2
376
+
377
+ wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
378
+ # print(wav2vec_embeds)
379
+ # print("audio.shape:",input_audios.shape)
380
+ wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer]
381
+ # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape)
382
+ wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
383
+ return wav2vec_embeds_last
384
+
385
+ def extract_mert_embeds(self, input_audios):
386
+ prompt_stride = 3
387
+ inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
388
+ input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
389
+ prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
390
+ mert_emb= prompt_embeds[-1]
391
+ mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1)
392
+
393
+ return mert_emb
394
+
395
+ def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer):
396
+ self.bestrq.eval()
397
+ # print("audio shape:",input_audio_0.shape)
398
+ input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
399
+ # print("input_wav_mean.shape:",input_wav_mean.shape)
400
+ # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device)
401
+ input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
402
+ layer_results = input_wav_mean['layer_results']
403
+ # print("layer_results.shape:",layer_results[layer].shape)
404
+ bestrq_emb = layer_results[layer]
405
+ bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
406
+ #[b,t,1024] t=t/960
407
+ #35.84s->batch,896,1024
408
+ return bestrq_emb
409
+
410
+
411
+ def extract_spk_embeds(self, input_audios):
412
+ spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
413
+ spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
414
+ return spk_embeds
415
+
416
+ def extract_lyric_feats(self, lyric):
417
+ with torch.no_grad():
418
+ try:
419
+ text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
420
+ except:
421
+ text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
422
+ text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
423
+ text_mask = text_mask.to(self.device)
424
+ text_encoder_hidden_states, text_mask, text_prompt_embeds = \
425
+ pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
426
+ text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
427
+ return text_encoder_hidden_states, text_mask
428
+
429
+ def extract_energy_bar(self, input_audios):
430
+ if(input_audios.shape[-1] % self.num_samples_perseg > 0):
431
+ energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
432
+ else:
433
+ energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
434
+ energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
435
+ energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
436
+ energy_embedding = self.energy_embedding(energy_bar)
437
+ energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
438
+ return energy_embedding
439
+
440
+ def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \
441
+ additional_feats = ['spk', 'lyric'], \
442
+ train_rvq=True, train_ssl=False,layer=5):
443
+ if not hasattr(self,"device"):
444
+ self.device = input_audios.device
445
+ if not hasattr(self,"dtype"):
446
+ self.dtype = input_audios.dtype
447
+ device = self.device
448
+ input_audio_0 = input_audios[:,0,:]
449
+ input_audio_1 = input_audios[:,1,:]
450
+ input_audio_0 = self.preprocess_audio(input_audio_0)
451
+ input_audio_1 = self.preprocess_audio(input_audio_1)
452
+ input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0
453
+ # energy_embedding = self.extract_energy_bar(input_audios)
454
+ # print("energy_embedding.shape:",energy_embedding.shape)
455
+ # with autocast(enabled=False):
456
+ if(train_ssl):
457
+ self.wav2vec.train()
458
+ wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
459
+ self.clap_embd_extractor.train()
460
+ prompt_embeds = self.extract_mert_embeds(input_audios)
461
+ if('spk' in additional_feats):
462
+ self.xvecmodel.train()
463
+ spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
464
+ else:
465
+ with torch.no_grad():
466
+ with autocast(enabled=False):
467
+ bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
468
+ # mert_emb = self.extract_mert_embeds(input_audios_mert)
469
+
470
+ wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2])
471
+
472
+ bestrq_emb = bestrq_emb.detach()
473
+ if('lyric' in additional_feats):
474
+ text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
475
+ else:
476
+ text_encoder_hidden_states, text_mask = None, None
477
+
478
+
479
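+ # During RVQ training, the number of active quantizer layers is sampled per
+ # batch (1 layer with probability 0.6, 2 layers 0.2, 4 layers 0.2), a
+ # quantizer-dropout-style schedule that keeps shorter code prefixes usable.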
+ if(train_rvq):
480
+ random_num=random.random()
481
+ if(random_num<0.6):
482
+ rvq_layer = 1
483
+ elif(random_num<0.8):
484
+ rvq_layer = 2
485
+ else:
486
+ rvq_layer = 4
487
+ quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t
488
+ else:
489
+ bestrq_emb = bestrq_emb.float()
490
+ self.rvq_bestrq_emb.eval()
491
+ # with autocast(enabled=False):
492
+ quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
493
+ commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
494
+ codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
495
+ quantized_bestrq_emb = quantized_bestrq_emb.detach()
496
+
497
+ commitment_loss = commitment_loss_bestrq_emb
498
+ codebook_loss = codebook_loss_bestrq_emb
499
+
500
+
501
+ alpha=1
502
+ quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
503
+
504
+ # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
505
+ # print("latent_masks.shape:",latent_masks.shape)
506
+ # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
507
+
508
+
509
+
510
+ scenario = np.random.choice(['start_seg', 'other_seg'])
511
+ if(scenario == 'other_seg'):
512
+ for binx in range(input_audios.shape[0]):
513
+ # latent_masks[binx,0:64] = 1
514
+ latent_masks[binx,0:random.randint(64,128)] = 1
515
+ quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
516
+ # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
517
+ # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape)
518
+ # print("latent_masks.shape:",latent_masks.shape)
519
+ quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
520
+ + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
521
+
522
+
523
+
524
+
525
+ if self.uncondition:
526
+ mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
527
+ if len(mask_indices) > 0:
528
+ quantized_bestrq_emb[mask_indices] = 0
529
+ # print("latents.shape:",latents.shape)
530
+ latents = latents.permute(0,2,1).contiguous()
531
+ latents = self.normfeat.project_sample(latents)
532
+ latents = latents.permute(0,2,1).contiguous()
533
+ incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
534
+ attention_mask=(latent_masks > 0.5)
535
+ B, L = attention_mask.size()
536
+ attention_mask = attention_mask.view(B, 1, L)
537
+ attention_mask = attention_mask * attention_mask.transpose(-1, -2)
538
+ attention_mask = attention_mask.unsqueeze(1)
539
+ # print("incontext_latents.shape:",incontext_latents.shape)
540
+ # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
541
+ latent_mask_input = self.mask_emb(latent_masks)
542
+ # estimator input per frame: 48 (mask embedding) + 64 (in-context latent) + 1024 (quantized BEST-RQ condition) + 64 (noisy latent) = 1200 = n_embd
543
+ loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
544
+ return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
545
+
546
+ def init_device_dtype(self, device, dtype):
547
+ self.device = device
548
+ self.dtype = dtype
549
+
550
+ @torch.no_grad()
551
+ def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1):
552
+ input_audio_0 = input_audios[[0],:]
553
+ input_audio_1 = input_audios[[1],:]
554
+ input_audio_0 = self.preprocess_audio(input_audio_0)
555
+ input_audio_1 = self.preprocess_audio(input_audio_1)
556
+
557
+ self.bestrq.eval()
558
+
559
+ # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
560
+ # bestrq_middle = bestrq_middle.detach()
561
+ # bestrq_last = bestrq_last.detach()
562
+ bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
563
+ bestrq_emb = bestrq_emb.detach()
564
+
565
+ # self.rvq_bestrq_middle.eval()
566
+ # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
567
+ # self.rvq_bestrq_last.eval()
568
+ # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
569
+
570
+ self.rvq_bestrq_emb.eval()
571
+ quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
572
+ codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
573
+ # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
574
+ # exit()
575
+
576
+
577
+ if('spk' in additional_feats):
578
+ self.xvecmodel.eval()
579
+ spk_embeds = self.extract_spk_embeds(input_audios)
580
+ else:
581
+ spk_embeds = None
582
+
583
+ # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
584
+ # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
585
+ # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
586
+ return [codes_bestrq_emb], [bestrq_emb], spk_embeds
587
+ # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
588
+
589
+ @torch.no_grad()
590
+ def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1):
591
+ input_audio_0 = input_audios[:,0,:]
592
+ input_audio_1 = input_audios[:,1,:]
593
+ input_audio_0 = self.preprocess_audio(input_audio_0)
594
+ input_audio_1 = self.preprocess_audio(input_audio_1)
595
+
596
+ self.bestrq.eval()
597
+
598
+ # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
599
+ # bestrq_middle = bestrq_middle.detach()
600
+ # bestrq_last = bestrq_last.detach()
601
+ bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
602
+ bestrq_emb = bestrq_emb.detach()
603
+
604
+ # self.rvq_bestrq_middle.eval()
605
+ # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
606
+ # self.rvq_bestrq_last.eval()
607
+ # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
608
+
609
+ self.rvq_bestrq_emb.eval()
610
+ quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
611
+ # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
612
+ codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
613
+ # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
614
+ # exit()
615
+
616
+
617
+ if('spk' in additional_feats):
618
+ self.xvecmodel.eval()
619
+ spk_embeds = self.extract_spk_embeds(input_audios)
620
+ else:
621
+ spk_embeds = None
622
+
623
+ # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
624
+ # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
625
+ # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
626
+ return [codes_bestrq_emb], [bestrq_emb], spk_embeds
627
+ # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
628
+
629
+ @torch.no_grad()
630
+ def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250):
631
+ input_audio_0 = input_audios[:,0,:]
632
+ input_audio_1 = input_audios[:,1,:]
633
+ input_audio_0 = self.preprocess_audio(input_audio_0)
634
+ input_audio_1 = self.preprocess_audio(input_audio_1)
635
+
636
+ self.bestrq.eval()
637
+
638
+ # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
639
+ # bestrq_middle = bestrq_middle.detach()
640
+ # bestrq_last = bestrq_last.detach()
641
+ bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
642
+ bestrq_emb = bestrq_emb.detach()
643
+
644
+ # self.rvq_bestrq_middle.eval()
645
+ # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
646
+ # self.rvq_bestrq_last.eval()
647
+ # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
648
+
649
+ self.rvq_bestrq_emb.eval()
650
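+ # Average-pool the BEST-RQ features along time by a factor of `ds` before
+ # quantization, so this variant emits codes at a much lower frame rate.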
+ bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds)
651
+ quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
652
+ # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
653
+ codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
654
+ # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
655
+ # exit()
656
+
657
+
658
+ if('spk' in additional_feats):
659
+ self.xvecmodel.eval()
660
+ spk_embeds = self.extract_spk_embeds(input_audios)
661
+ else:
662
+ spk_embeds = None
663
+
664
+ # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
665
+ # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
666
+ # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
667
+ return [codes_bestrq_emb], [bestrq_emb], spk_embeds
668
+ # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
669
+
670
+ @torch.no_grad()
671
+ def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127,
672
+ guidance_scale=2, num_steps=20,
673
+ disable_progress=True, scenario='start_seg'):
674
+ classifier_free_guidance = guidance_scale > 1.0
675
+ device = self.device
676
+ dtype = self.dtype
677
+ # codes_bestrq_middle, codes_bestrq_last = codes
678
+ codes_bestrq_emb = codes[0]
679
+
680
+
681
+ batch_size = codes_bestrq_emb.shape[0]
682
+
683
+
684
+ quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
685
+ # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
686
+ quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
687
+ print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
688
+ # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
689
+
690
+
691
+
692
+
693
+ if('spk' in additional_feats):
694
+ spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
695
+
696
+ num_frames = quantized_bestrq_emb.shape[1]
697
+
698
+ num_channels_latents = self.num_channels
699
+ shape = (batch_size, num_frames, 64)
700
+ latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
701
+
702
+
703
+
704
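+ # latent_masks assigns a role to every latent frame: 0 = outside the segment
+ # (its conditioning is replaced by the learned zero embedding and it is masked
+ # from attention), 1 = in-context frames copied from true_latents, 2 = frames
+ # to be generated.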
+ latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
705
+ latent_masks[:,0:latent_length] = 2
706
+ if(scenario=='other_seg'):
707
+ latent_masks[:,0:incontext_length] = 1
708
+
709
+
710
+
711
+ quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
712
+ + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
713
+ true_latents = true_latents.permute(0,2,1).contiguous()
714
+ true_latents = self.normfeat.project_sample(true_latents)
715
+ true_latents = true_latents.permute(0,2,1).contiguous()
716
+ incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
717
+ incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
718
+
719
+
720
+ attention_mask=(latent_masks > 0.5)
721
+ B, L = attention_mask.size()
722
+ attention_mask = attention_mask.view(B, 1, L)
723
+ attention_mask = attention_mask * attention_mask.transpose(-1, -2)
724
+ attention_mask = attention_mask.unsqueeze(1)
725
+ latent_mask_input = self.mask_emb(latent_masks)
726
+
727
+ if('spk' in additional_feats):
728
+ # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
729
+ additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1)
730
+ else:
731
+ # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
732
+ additional_model_input = torch.cat([quantized_bestrq_emb],1)
733
+
734
+ temperature = 1.0
735
+ t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
736
+ latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
737
+
738
+ latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
739
+ latents = latents.permute(0,2,1).contiguous()
740
+ latents = self.normfeat.return_sample(latents)
741
+ # latents = latents.permute(0,2,1).contiguous()
742
+ return latents
743
+
744
+ @torch.no_grad()
745
+ def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
746
+ disable_progress=True,layer=5,scenario='start_seg',rvq_num=1):
747
+ codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num)
748
+
749
+ latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
750
+ guidance_scale=guidance_scale, num_steps=num_steps, \
751
+ disable_progress=disable_progress,scenario=scenario)
752
+ return latents
753
+
754
+ @torch.no_grad()
755
+ def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
756
+ disable_progress=True,layer=5,scenario='start_seg'):
757
+ codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
758
+ import time
759
+ start = time.time()
760
+ latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
761
+ guidance_scale=guidance_scale, num_steps=num_steps, \
762
+ disable_progress=disable_progress,scenario=scenario)
763
+ return latents,time.time()-start
764
+
765
+ def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
766
+ divisor = 4
767
+ shape = (batch_size, num_channels_latents, num_frames, 32)
768
+ if(num_frames%divisor>0):
769
+ num_frames = round(num_frames/float(divisor))*divisor
770
+ shape = (batch_size, num_channels_latents, num_frames, 32)
771
+ latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
772
+ return latents
773
+
774
+
codeclm/tokenizer/Flow1dVAE/model_4rvq.py CHANGED
@@ -1,774 +1,774 @@
1
- import yaml
2
- import random
3
- import inspect
4
- import numpy as np
5
- from tqdm import tqdm
6
- import typing as tp
7
- from abc import ABC
8
-
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
- import torchaudio
13
-
14
- from einops import repeat
15
- from tools.torch_tools import wav_to_fbank
16
-
17
- import diffusers
18
- from diffusers.utils.torch_utils import randn_tensor
19
- from diffusers import DDPMScheduler
20
- from models.transformer_2d_flow import Transformer2DModel
21
- from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel
22
- # from tools.get_mulan import get_mulan
23
- from third_party.wespeaker.extract_embd import XVECModel
24
- # from libs.rvq2 import RVQEmbedding
25
- from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize
26
-
27
- from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
28
- from models_gpt.models.gpt2_config import GPT2Config
29
-
30
- from torch.cuda.amp import autocast
31
-
32
-
33
- from our_MERT_BESTRQ.test import load_model
34
-
35
- class HubertModelWithFinalProj(HubertModel):
36
- def __init__(self, config):
37
- super().__init__(config)
38
-
39
- # The final projection layer is only used for backward compatibility.
40
- # Following https://github.com/auspicious3000/contentvec/issues/6
41
- # Removing this layer is necessary to achieve the desired outcome.
42
- print("hidden_size:",config.hidden_size)
43
- print("classifier_proj_size:",config.classifier_proj_size)
44
- self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
45
-
46
-
47
- class SampleProcessor(torch.nn.Module):
48
- def project_sample(self, x: torch.Tensor):
49
- """Project the original sample to the 'space' where the diffusion will happen."""
50
- """Project back from diffusion space to the actual sample space."""
51
- return z
52
-
53
- class Feature1DProcessor(SampleProcessor):
54
- def __init__(self, dim: int = 100, power_std = 1., \
55
- num_samples: int = 100_000, cal_num_frames: int = 600):
56
- super().__init__()
57
-
58
- self.num_samples = num_samples
59
- self.dim = dim
60
- self.power_std = power_std
61
- self.cal_num_frames = cal_num_frames
62
- self.register_buffer('counts', torch.zeros(1))
63
- self.register_buffer('sum_x', torch.zeros(dim))
64
- self.register_buffer('sum_x2', torch.zeros(dim))
65
- self.register_buffer('sum_target_x2', torch.zeros(dim))
66
- self.counts: torch.Tensor
67
- self.sum_x: torch.Tensor
68
- self.sum_x2: torch.Tensor
69
-
70
- @property
71
- def mean(self):
72
- mean = self.sum_x / self.counts
73
- if(self.counts < 10):
74
- mean = torch.zeros_like(mean)
75
- return mean
76
-
77
- @property
78
- def std(self):
79
- std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
80
- if(self.counts < 10):
81
- std = torch.ones_like(std)
82
- return std
83
-
84
- @property
85
- def target_std(self):
86
- return 1
87
-
88
- def project_sample(self, x: torch.Tensor):
89
- assert x.dim() == 3
90
- if self.counts.item() < self.num_samples:
91
- self.counts += len(x)
92
- self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
93
- self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
94
- rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
95
- x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
96
- return x
97
-
98
- def return_sample(self, x: torch.Tensor):
99
- assert x.dim() == 3
100
- rescale = (self.std / self.target_std) ** self.power_std
101
- # print(rescale, self.mean)
102
- x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
103
- return x
104
-
105
- def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
106
- if(prior_text_encoder_hidden_states.shape[1]<len_size):
107
- prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
108
- torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
109
- prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
110
- dtype=prior_text_encoder_hidden_states.dtype)],1)
111
- prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
112
- else:
113
- prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
114
- prior_text_mask = prior_text_mask[:,0:len_size]
115
- prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
116
- return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
117
-
118
- class BASECFM(torch.nn.Module, ABC):
119
- def __init__(
120
- self,
121
- estimator,
122
- mlp,
123
- ssl_layer
124
- ):
125
- super().__init__()
126
- self.sigma_min = 1e-4
127
-
128
- self.estimator = estimator
129
- self.mlp = mlp
130
- self.ssl_layer = ssl_layer
131
-
132
- @torch.inference_mode()
133
- def forward(self, mu, n_timesteps, temperature=1.0):
134
- """Forward diffusion
135
-
136
- Args:
137
- mu (torch.Tensor): output of encoder
138
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
139
- n_timesteps (int): number of diffusion steps
140
- temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
141
-
142
- Returns:
143
- sample: generated mel-spectrogram
144
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
145
- """
146
- z = torch.randn_like(mu) * temperature
147
- t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
148
- return self.solve_euler(z, t_span=t_span)
149
-
150
- def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
151
- """
152
- Fixed euler solver for ODEs.
153
- Args:
154
- x (torch.Tensor): random noise
155
- t_span (torch.Tensor): n_timesteps interpolated
156
- shape: (n_timesteps + 1,)
157
- mu (torch.Tensor): output of encoder
158
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
159
- """
160
- t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
161
- noise = x.clone()
162
-
163
- # I am storing this because I can later plot it by putting a debugger here and saving it to a file
164
- # Or in future might add like a return_all_steps flag
165
- sol = []
166
-
167
- for step in tqdm(range(1, len(t_span))):
168
- print("incontext_x.shape:",incontext_x.shape)
169
- print("noise.shape:",noise.shape)
170
- print("t.shape:",t.shape)
171
- x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
172
- if(guidance_scale > 1.0):
173
-
174
- model_input = torch.cat([ \
175
- torch.cat([latent_mask_input, latent_mask_input], 0), \
176
- torch.cat([incontext_x, incontext_x], 0), \
177
- torch.cat([torch.zeros_like(mu), mu], 0), \
178
- torch.cat([x, x], 0), \
179
- ], 2)
180
- timestep=t.unsqueeze(-1).repeat(2)
181
-
182
- dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
183
- dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0)
184
- dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond)
185
- else:
186
- model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
187
- timestep=t.unsqueeze(-1)
188
- dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
189
-
190
- dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
191
- print("dphi_dt.shape:",dphi_dt.shape)
192
- print("x.shape:",x.shape)
193
-
194
- x = x + dt * dphi_dt
195
- t = t + dt
196
- sol.append(x)
197
- if step < len(t_span) - 1:
198
- dt = t_span[step + 1] - t
199
-
200
- return sol[-1]
201
-
202
- def projection_loss(self,hidden_proj, bestrq_emb):
203
- bsz = hidden_proj.shape[0]
204
-
205
- hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
206
- bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
207
-
208
- proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
209
- proj_loss = 1+proj_loss.mean()
210
-
211
- return proj_loss
212
-
213
- def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
214
- """Computes diffusion loss
215
-
216
- Args:
217
- x1 (torch.Tensor): Target
218
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
219
- mu (torch.Tensor): output of encoder
220
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
221
-
222
- Returns:
223
- loss: conditional flow matching loss
224
- y: conditional flow
225
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
226
- """
227
- b = mu[0].shape[0]
228
- len_x = x1.shape[2]
229
- # random timestep
230
- if(validation_mode):
231
- t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
232
- else:
233
- t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
234
- # sample noise p(x_0)
235
- z = torch.randn_like(x1)
236
-
237
- y = (1 - (1 - self.sigma_min) * t) * z + t * x1
238
- u = x1 - (1 - self.sigma_min) * z
239
- # print("y.shape:",y.shape)
240
- #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state
241
- model_input = torch.cat([*mu,y], 2)
242
- t=t.squeeze(-1).squeeze(-1)
243
- # print("model_input.shape:",model_input.shape)
244
- # print("attention_mask.shape:",attention_mask.shape)
245
- out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
246
- hidden_layer = out.hidden_states[self.ssl_layer]
247
- hidden_proj = self.mlp(hidden_layer)
248
- # print("hidden_proj.shape:",hidden_proj.shape)
249
- # print("mert_emb.shape:",mert_emb.shape)
250
- # exit()
251
-
252
-
253
- out = out.last_hidden_state
254
-
255
- out=out[:,:,-len_x:]
256
- # out=self.proj_out(out)
257
-
258
- weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
259
- # print("out.shape",out.shape)
260
- # print("u.shape",u.shape)
261
- loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
262
- # print("hidden_proj.shape:",hidden_proj.shape)
263
- # print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
264
- loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
265
- loss = loss_re + loss_cos * 0.5
266
- # print("loss_cos:",loss_cos,loss_cos.device)
267
- print("loss:",loss,loss.device)
268
- # exit()
269
- return loss, loss_re, loss_cos
270
-
271
- class PromptCondAudioDiffusion(nn.Module):
272
- def __init__(
273
- self,
274
- num_channels,
275
- unet_model_name=None,
276
- unet_model_config_path=None,
277
- snr_gamma=None,
278
- hubert_layer=None,
279
- ssl_layer=None,
280
- uncondition=True,
281
- out_paint=False,
282
- ):
283
- super().__init__()
284
-
285
- assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
286
-
287
- self.unet_model_name = unet_model_name
288
- self.unet_model_config_path = unet_model_config_path
289
- self.snr_gamma = snr_gamma
290
- self.uncondition = uncondition
291
- self.num_channels = num_channels
292
- self.hubert_layer = hubert_layer
293
- self.ssl_layer = ssl_layer
294
-
295
- # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
296
- self.normfeat = Feature1DProcessor(dim=64)
297
-
298
- self.sample_rate = 48000
299
- self.num_samples_perseg = self.sample_rate * 20 // 1000
300
- self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
301
- self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
302
- # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
303
- # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
304
- self.bestrq = load_model(
305
- model_dir='path/to/our-MERT/mert_fairseq',
306
- checkpoint_dir='checkpoint-120000.pt',
307
- )
308
- self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
309
- self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
310
- for v in self.bestrq.parameters():v.requires_grad = False
311
- self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
312
- # for v in self.rvq_bestrq_emb.parameters():
313
- # print(v)
314
- freeze_parameters='quantizers.0'
315
- for name, param in self.rvq_bestrq_emb.named_parameters():
316
- if freeze_parameters in name:
317
- param.requires_grad = False
318
- print("Freezing RVQ parameters:", name)
319
- self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
320
- for v in self.hubert.parameters():v.requires_grad = False
321
- self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
322
- # self.xvecmodel = XVECModel()
323
- config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
324
- unet = GPT2Model(config)
325
- mlp = nn.Sequential(
326
- nn.Linear(1200, 1024),
327
- nn.SiLU(),
328
- nn.Linear(1024, 1024),
329
- nn.SiLU(),
330
- nn.Linear(1024, 768)
331
- )
332
- self.set_from = "random"
333
- self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
334
- self.mask_emb = torch.nn.Embedding(3, 48)
335
- print("Transformer initialized from pretrain.")
336
- torch.cuda.empty_cache()
337
- # self.unet.set_attn_processor(AttnProcessor2_0())
338
- # self.unet.set_use_memory_efficient_attention_xformers(True)
339
-
340
- # self.start_embedding = nn.Parameter(torch.randn(1,1024))
341
- # self.end_embedding = nn.Parameter(torch.randn(1,1024))
342
-
343
- def compute_snr(self, timesteps):
344
- """
345
- Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
346
- """
347
- alphas_cumprod = self.noise_scheduler.alphas_cumprod
348
- sqrt_alphas_cumprod = alphas_cumprod**0.5
349
- sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
350
-
351
- # Expand the tensors.
352
- # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
353
- sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
354
- while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
355
- sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
356
- alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
357
-
358
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
359
- while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
360
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
361
- sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
362
-
363
- # Compute SNR.
364
- snr = (alpha / sigma) ** 2
365
- return snr
366
-
367
- def preprocess_audio(self, input_audios, threshold=0.9):
368
- assert len(input_audios.shape) == 2, input_audios.shape
369
- norm_value = torch.ones_like(input_audios[:,0])
370
- max_volume = input_audios.abs().max(dim=-1)[0]
371
- norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
372
- return input_audios/norm_value.unsqueeze(-1)
373
-
374
- def extract_wav2vec_embeds(self, input_audios,output_len):
375
- wav2vec_stride = 2
376
-
377
- wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
378
- # print(wav2vec_embeds)
379
- # print("audio.shape:",input_audios.shape)
380
- wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer]
381
- # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape)
382
- wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
383
- return wav2vec_embeds_last
384
-
385
- def extract_mert_embeds(self, input_audios):
386
- prompt_stride = 3
387
- inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
388
- input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
389
- prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
390
- mert_emb= prompt_embeds[-1]
391
- mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1)
392
-
393
- return mert_emb
394
-
395
- def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer):
396
- self.bestrq.eval()
397
- # print("audio shape:",input_audio_0.shape)
398
- input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
399
- # print("input_wav_mean.shape:",input_wav_mean.shape)
400
- # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device)
401
- input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
402
- layer_results = input_wav_mean['layer_results']
403
- # print("layer_results.shape:",layer_results[layer].shape)
404
- bestrq_emb = layer_results[layer]
405
- bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
406
- #[b,t,1024] t=t/960
407
- #35.84s->batch,896,1024
408
- return bestrq_emb
409
-
410
-
411
- def extract_spk_embeds(self, input_audios):
412
- spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
413
- spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
414
- return spk_embeds
415
-
416
- def extract_lyric_feats(self, lyric):
417
- with torch.no_grad():
418
- try:
419
- text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
420
- except:
421
- text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
422
- text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
423
- text_mask = text_mask.to(self.device)
424
- text_encoder_hidden_states, text_mask, text_prompt_embeds = \
425
- pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
426
- text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
427
- return text_encoder_hidden_states, text_mask
428
-
429
- def extract_energy_bar(self, input_audios):
430
- if(input_audios.shape[-1] % self.num_samples_perseg > 0):
431
- energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
432
- else:
433
- energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
434
- energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
435
- energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
436
- energy_embedding = self.energy_embedding(energy_bar)
437
- energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
438
- return energy_embedding
439
-
440
- def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \
441
- additional_feats = ['spk', 'lyric'], \
442
- train_rvq=True, train_ssl=False,layer=5):
443
- if not hasattr(self,"device"):
444
- self.device = input_audios.device
445
- if not hasattr(self,"dtype"):
446
- self.dtype = input_audios.dtype
447
- device = self.device
448
- input_audio_0 = input_audios[:,0,:]
449
- input_audio_1 = input_audios[:,1,:]
450
- input_audio_0 = self.preprocess_audio(input_audio_0)
451
- input_audio_1 = self.preprocess_audio(input_audio_1)
452
- input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0
453
- # energy_embedding = self.extract_energy_bar(input_audios)
454
- # print("energy_embedding.shape:",energy_embedding.shape)
455
- # with autocast(enabled=False):
456
- if(train_ssl):
457
- self.wav2vec.train()
458
- wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
459
- self.clap_embd_extractor.train()
460
- prompt_embeds = self.extract_mert_embeds(input_audios)
461
- if('spk' in additional_feats):
462
- self.xvecmodel.train()
463
- spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
464
- else:
465
- with torch.no_grad():
466
- with autocast(enabled=False):
467
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
468
- # mert_emb = self.extract_mert_embeds(input_audios_mert)
469
-
470
- wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2])
471
-
472
- bestrq_emb = bestrq_emb.detach()
473
- if('lyric' in additional_feats):
474
- text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
475
- else:
476
- text_encoder_hidden_states, text_mask = None, None
477
-
478
-
479
- if(train_rvq):
480
- random_num=random.random()
481
- if(random_num<0.6):
482
- rvq_layer = 1
483
- elif(random_num<0.8):
484
- rvq_layer = 2
485
- else:
486
- rvq_layer = 4
487
- quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t
488
- else:
489
- bestrq_emb = bestrq_emb.float()
490
- self.rvq_bestrq_emb.eval()
491
- # with autocast(enabled=False):
492
- quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
493
- commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
494
- codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
495
- quantized_bestrq_emb = quantized_bestrq_emb.detach()
496
-
497
- commitment_loss = commitment_loss_bestrq_emb
498
- codebook_loss = codebook_loss_bestrq_emb
499
-
500
-
501
- alpha=1
502
- quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
503
-
504
- # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
505
- # print("latent_masks.shape:",latent_masks.shape)
506
- # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
507
-
508
-
509
-
510
- scenario = np.random.choice(['start_seg', 'other_seg'])
511
- if(scenario == 'other_seg'):
512
- for binx in range(input_audios.shape[0]):
513
- # latent_masks[binx,0:64] = 1
514
- latent_masks[binx,0:random.randint(64,128)] = 1
515
- quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
516
- # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
517
- # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape)
518
- # print("latent_masks.shape:",latent_masks.shape)
519
- quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
520
- + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
521
-
522
-
523
-
524
-
525
- if self.uncondition:
526
- mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
527
- if len(mask_indices) > 0:
528
- quantized_bestrq_emb[mask_indices] = 0
529
- # print("latents.shape:",latents.shape)
530
- latents = latents.permute(0,2,1).contiguous()
531
- latents = self.normfeat.project_sample(latents)
532
- latents = latents.permute(0,2,1).contiguous()
533
- incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
534
- attention_mask=(latent_masks > 0.5)
535
- B, L = attention_mask.size()
536
- attention_mask = attention_mask.view(B, 1, L)
537
- attention_mask = attention_mask * attention_mask.transpose(-1, -2)
538
- attention_mask = attention_mask.unsqueeze(1)
539
- # print("incontext_latents.shape:",incontext_latents.shape)
540
- # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
541
- latent_mask_input = self.mask_emb(latent_masks)
542
- #64+48+64+1024
543
- loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
544
- return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
545
-
546
- def init_device_dtype(self, device, dtype):
547
- self.device = device
548
- self.dtype = dtype
549
-
550
- @torch.no_grad()
551
- def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1):
552
- input_audio_0 = input_audios[[0],:]
553
- input_audio_1 = input_audios[[1],:]
554
- input_audio_0 = self.preprocess_audio(input_audio_0)
555
- input_audio_1 = self.preprocess_audio(input_audio_1)
556
-
557
- self.bestrq.eval()
558
-
559
- # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
560
- # bestrq_middle = bestrq_middle.detach()
561
- # bestrq_last = bestrq_last.detach()
562
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
563
- bestrq_emb = bestrq_emb.detach()
564
-
565
- # self.rvq_bestrq_middle.eval()
566
- # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
567
- # self.rvq_bestrq_last.eval()
568
- # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
569
-
570
- self.rvq_bestrq_emb.eval()
571
- quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
572
- codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
573
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
574
- # exit()
575
-
576
-
577
- if('spk' in additional_feats):
578
- self.xvecmodel.eval()
579
- spk_embeds = self.extract_spk_embeds(input_audios)
580
- else:
581
- spk_embeds = None
582
-
583
- # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
584
- # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
585
- # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
586
- return [codes_bestrq_emb], [bestrq_emb], spk_embeds
587
- # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
588
-
589
- @torch.no_grad()
590
- def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1):
591
- input_audio_0 = input_audios[:,0,:]
592
- input_audio_1 = input_audios[:,1,:]
593
- input_audio_0 = self.preprocess_audio(input_audio_0)
594
- input_audio_1 = self.preprocess_audio(input_audio_1)
595
-
596
- self.bestrq.eval()
597
-
598
- # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
599
- # bestrq_middle = bestrq_middle.detach()
600
- # bestrq_last = bestrq_last.detach()
601
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
602
- bestrq_emb = bestrq_emb.detach()
603
-
604
- # self.rvq_bestrq_middle.eval()
605
- # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
606
- # self.rvq_bestrq_last.eval()
607
- # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
608
-
609
- self.rvq_bestrq_emb.eval()
610
- quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
611
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
612
- codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
613
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
614
- # exit()
615
-
616
-
617
- if('spk' in additional_feats):
618
- self.xvecmodel.eval()
619
- spk_embeds = self.extract_spk_embeds(input_audios)
620
- else:
621
- spk_embeds = None
622
-
623
- # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
624
- # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
625
- # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
626
- return [codes_bestrq_emb], [bestrq_emb], spk_embeds
627
- # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
628
-
629
- @torch.no_grad()
630
- def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250):
631
- input_audio_0 = input_audios[:,0,:]
632
- input_audio_1 = input_audios[:,1,:]
633
- input_audio_0 = self.preprocess_audio(input_audio_0)
634
- input_audio_1 = self.preprocess_audio(input_audio_1)
635
-
636
- self.bestrq.eval()
637
-
638
- # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
639
- # bestrq_middle = bestrq_middle.detach()
640
- # bestrq_last = bestrq_last.detach()
641
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
642
- bestrq_emb = bestrq_emb.detach()
643
-
644
- # self.rvq_bestrq_middle.eval()
645
- # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
646
- # self.rvq_bestrq_last.eval()
647
- # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
648
-
649
- self.rvq_bestrq_emb.eval()
650
- bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds)
651
- quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
652
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
653
- codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
654
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
655
- # exit()
656
-
657
-
658
- if('spk' in additional_feats):
659
- self.xvecmodel.eval()
660
- spk_embeds = self.extract_spk_embeds(input_audios)
661
- else:
662
- spk_embeds = None
663
-
664
- # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
665
- # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
666
- # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
667
- return [codes_bestrq_emb], [bestrq_emb], spk_embeds
668
- # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
669
-
670
- @torch.no_grad()
671
- def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127,
672
- guidance_scale=2, num_steps=20,
673
- disable_progress=True, scenario='start_seg'):
674
- classifier_free_guidance = guidance_scale > 1.0
675
- device = self.device
676
- dtype = self.dtype
677
- # codes_bestrq_middle, codes_bestrq_last = codes
678
- codes_bestrq_emb = codes[0]
679
-
680
-
681
- batch_size = codes_bestrq_emb.shape[0]
682
-
683
-
684
- quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
685
- # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
686
- quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
687
- print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
688
- # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
689
-
690
-
691
-
692
-
693
- if('spk' in additional_feats):
694
- spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
695
-
696
- num_frames = quantized_bestrq_emb.shape[1]
697
-
698
- num_channels_latents = self.num_channels
699
- shape = (batch_size, num_frames, 64)
700
- latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
701
-
702
-
703
-
704
- latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
705
- latent_masks[:,0:latent_length] = 2
706
- if(scenario=='other_seg'):
707
- latent_masks[:,0:incontext_length] = 1
708
-
709
-
710
-
711
- quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
712
- + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
713
- true_latents = true_latents.permute(0,2,1).contiguous()
714
- true_latents = self.normfeat.project_sample(true_latents)
715
- true_latents = true_latents.permute(0,2,1).contiguous()
716
- incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
717
- incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
718
-
719
-
720
- attention_mask=(latent_masks > 0.5)
721
- B, L = attention_mask.size()
722
- attention_mask = attention_mask.view(B, 1, L)
723
- attention_mask = attention_mask * attention_mask.transpose(-1, -2)
724
- attention_mask = attention_mask.unsqueeze(1)
725
- latent_mask_input = self.mask_emb(latent_masks)
726
-
727
- if('spk' in additional_feats):
728
- # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
729
- additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1)
730
- else:
731
- # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
732
- additional_model_input = torch.cat([quantized_bestrq_emb],1)
733
-
734
- temperature = 1.0
735
- t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
736
- latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
737
-
738
- latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
739
- latents = latents.permute(0,2,1).contiguous()
740
- latents = self.normfeat.return_sample(latents)
741
- # latents = latents.permute(0,2,1).contiguous()
742
- return latents
743
-
744
- @torch.no_grad()
745
- def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
746
- disable_progress=True,layer=5,scenario='start_seg',rvq_num=1):
747
- codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num)
748
-
749
- latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
750
- guidance_scale=guidance_scale, num_steps=num_steps, \
751
- disable_progress=disable_progress,scenario=scenario)
752
- return latents
753
-
754
- @torch.no_grad()
755
- def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
756
- disable_progress=True,layer=5,scenario='start_seg'):
757
- codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
758
- import time
759
- start = time.time()
760
- latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
761
- guidance_scale=guidance_scale, num_steps=num_steps, \
762
- disable_progress=disable_progress,scenario=scenario)
763
- return latents,time.time()-start
764
-
765
- def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
766
- divisor = 4
767
- shape = (batch_size, num_channels_latents, num_frames, 32)
768
- if(num_frames%divisor>0):
769
- num_frames = round(num_frames/float(divisor))*divisor
770
- shape = (batch_size, num_channels_latents, num_frames, 32)
771
- latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
772
- return latents
773
-
774
-
 
1
+ import yaml
2
+ import random
3
+ import inspect
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ import typing as tp
7
+ from abc import ABC
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torchaudio
13
+
14
+ from einops import repeat
15
+ from tools.torch_tools import wav_to_fbank
16
+
17
+ import diffusers
18
+ from diffusers.utils.torch_utils import randn_tensor
19
+ from diffusers import DDPMScheduler
20
+ from models.transformer_2d_flow import Transformer2DModel
21
+ from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel
22
+ # from tools.get_mulan import get_mulan
23
+ from third_party.wespeaker.extract_embd import XVECModel
24
+ # from libs.rvq2 import RVQEmbedding
25
+ from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize
26
+
27
+ from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
28
+ from models_gpt.models.gpt2_config import GPT2Config
29
+
30
+ from torch.cuda.amp import autocast
31
+
32
+
33
+ from our_MERT_BESTRQ.test import load_model
34
+
35
+ class HubertModelWithFinalProj(HubertModel):
36
+ def __init__(self, config):
37
+ super().__init__(config)
38
+
39
+ # The final projection layer is only used for backward compatibility.
40
+ # Following https://github.com/auspicious3000/contentvec/issues/6
41
+ # Removing this layer is necessary to achieve the desired outcome.
42
+ print("hidden_size:",config.hidden_size)
43
+ print("classifier_proj_size:",config.classifier_proj_size)
44
+ self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
45
+
46
+
47
+ class SampleProcessor(torch.nn.Module):
48
+ def project_sample(self, x: torch.Tensor):
49
+ """Project the original sample to the 'space' where the diffusion will happen."""
50
+ """Project back from diffusion space to the actual sample space."""
51
+ return z
52
+
53
+ class Feature1DProcessor(SampleProcessor):
54
+ def __init__(self, dim: int = 100, power_std = 1., \
55
+ num_samples: int = 100_000, cal_num_frames: int = 600):
56
+ super().__init__()
57
+
58
+ self.num_samples = num_samples
59
+ self.dim = dim
60
+ self.power_std = power_std
61
+ self.cal_num_frames = cal_num_frames
62
+ self.register_buffer('counts', torch.zeros(1))
63
+ self.register_buffer('sum_x', torch.zeros(dim))
64
+ self.register_buffer('sum_x2', torch.zeros(dim))
65
+ self.register_buffer('sum_target_x2', torch.zeros(dim))
66
+ self.counts: torch.Tensor
67
+ self.sum_x: torch.Tensor
68
+ self.sum_x2: torch.Tensor
69
+
70
+ @property
71
+ def mean(self):
72
+ mean = self.sum_x / self.counts
73
+ if(self.counts < 10):
74
+ mean = torch.zeros_like(mean)
75
+ return mean
76
+
77
+ @property
78
+ def std(self):
79
+ std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
80
+ if(self.counts < 10):
81
+ std = torch.ones_like(std)
82
+ return std
83
+
84
+ @property
85
+ def target_std(self):
86
+ return 1
87
+
88
+ def project_sample(self, x: torch.Tensor):
89
+ assert x.dim() == 3
90
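+ # Accumulate running per-channel statistics from the first num_samples batches; the mean/std
+ # properties above fall back to 0/1 until at least 10 batches have been counted.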
+ if self.counts.item() < self.num_samples:
91
+ self.counts += len(x)
92
+ self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
93
+ self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
94
+ rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
95
+ x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
96
+ return x
97
+
98
+ def return_sample(self, x: torch.Tensor):
99
+ assert x.dim() == 3
100
+ rescale = (self.std / self.target_std) ** self.power_std
101
+ # print(rescale, self.mean)
102
+ x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
103
+ return x
104
+
105
+ def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
106
+ if(prior_text_encoder_hidden_states.shape[1]<len_size):
107
+ prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
108
+ torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
109
+ prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
110
+ dtype=prior_text_encoder_hidden_states.dtype)],1)
111
+ prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
112
+ else:
113
+ prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
114
+ prior_text_mask = prior_text_mask[:,0:len_size]
115
+ prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
116
+ return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
117
+
118
+ class BASECFM(torch.nn.Module, ABC):
119
+ def __init__(
120
+ self,
121
+ estimator,
122
+ mlp,
123
+ ssl_layer
124
+ ):
125
+ super().__init__()
126
+ self.sigma_min = 1e-4
127
+
128
+ self.estimator = estimator
129
+ self.mlp = mlp
130
+ self.ssl_layer = ssl_layer
131
+
132
+ @torch.inference_mode()
133
+ def forward(self, mu, n_timesteps, temperature=1.0):
134
+ """Forward diffusion
135
+
136
+ Args:
137
+ mu (torch.Tensor): output of encoder
138
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
139
+ n_timesteps (int): number of diffusion steps
140
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
141
+
142
+ Returns:
143
+ sample: generated mel-spectrogram
144
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
145
+ """
146
+ z = torch.randn_like(mu) * temperature
147
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
148
+ return self.solve_euler(z, t_span=t_span)
149
+
150
+ def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
151
+ """
152
+ Fixed Euler solver for ODEs.
153
+ Args:
154
+ x (torch.Tensor): random noise
155
+ t_span (torch.Tensor): n_timesteps interpolated
156
+ shape: (n_timesteps + 1,)
157
+ mu (torch.Tensor): output of encoder
158
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
159
+ """
160
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
161
+ noise = x.clone()
162
+
163
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
164
+ # Or in the future we might add a return_all_steps flag
165
+ sol = []
166
+
167
+ for step in tqdm(range(1, len(t_span))):
168
+ print("incontext_x.shape:",incontext_x.shape)
169
+ print("noise.shape:",noise.shape)
170
+ print("t.shape:",t.shape)
171
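+ # Re-noise the in-context (prompt) frames to the current flow time t: interpolate between the
+ # initial noise and the known in-context latents so they stay on the same trajectory as the
+ # frames being generated.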
+ x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
172
+ if(guidance_scale > 1.0):
173
+
174
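+ # Classifier-free guidance: the batch is doubled, the first half is run with zeroed
+ # conditioning (mu) and the second half with the real conditioning; the two velocity
+ # predictions are blended with guidance_scale a few lines below.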
+ model_input = torch.cat([ \
175
+ torch.cat([latent_mask_input, latent_mask_input], 0), \
176
+ torch.cat([incontext_x, incontext_x], 0), \
177
+ torch.cat([torch.zeros_like(mu), mu], 0), \
178
+ torch.cat([x, x], 0), \
179
+ ], 2)
180
+ timestep=t.unsqueeze(-1).repeat(2)
181
+
182
+ dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
183
+ dphi_dt_uncond, dphi_dt_cond = dphi_dt.chunk(2,0)
184
+ dphi_dt = dphi_dt_uncond + guidance_scale * (dphi_dt_cond - dphi_dt_uncond)
185
+ else:
186
+ model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
187
+ timestep=t.unsqueeze(-1)
188
+ dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
189
+
190
+ dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
191
+ print("dphi_dt.shape:",dphi_dt.shape)
192
+ print("x.shape:",x.shape)
193
+
194
+ x = x + dt * dphi_dt
195
+ t = t + dt
196
+ sol.append(x)
197
+ if step < len(t_span) - 1:
198
+ dt = t_span[step + 1] - t
199
+
200
+ return sol[-1]
201
+
202
+ def projection_loss(self,hidden_proj, bestrq_emb):
203
+ bsz = hidden_proj.shape[0]
204
+
205
+ hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
206
+ bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
207
+
208
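+ # Cosine-alignment loss: 1 - mean cosine similarity between the projected hidden states
+ # and the SSL embeddings, so perfectly aligned embeddings give a loss of 0.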
+ proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
209
+ proj_loss = 1+proj_loss.mean()
210
+
211
+ return proj_loss
212
+
213
+ def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
214
+ """Computes diffusion loss
215
+
216
+ Args:
217
+ x1 (torch.Tensor): Target
218
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
219
+ mu (torch.Tensor): output of encoder
220
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
221
+
222
+ Returns:
223
+ loss: conditional flow matching loss
224
+ y: conditional flow
225
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
226
+ """
227
+ b = mu[0].shape[0]
228
+ len_x = x1.shape[2]
229
+ # random timestep
230
+ if(validation_mode):
231
+ t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
232
+ else:
233
+ t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
234
+ # sample noise p(x_0)
235
+ z = torch.randn_like(x1)
236
+
237
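+ # Conditional flow matching: y interpolates from noise z at t=0 to the target x1 at t=1,
+ # and u = x1 - (1 - sigma_min) * z is the constant velocity the estimator is trained to predict.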
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
238
+ u = x1 - (1 - self.sigma_min) * z
239
+ # print("y.shape:",y.shape)
240
+ #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state
241
+ model_input = torch.cat([*mu,y], 2)
242
+ t=t.squeeze(-1).squeeze(-1)
243
+ # print("model_input.shape:",model_input.shape)
244
+ # print("attention_mask.shape:",attention_mask.shape)
245
+ out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
246
+ hidden_layer = out.hidden_states[self.ssl_layer]
247
+ hidden_proj = self.mlp(hidden_layer)
248
+ # print("hidden_proj.shape:",hidden_proj.shape)
249
+ # print("mert_emb.shape:",mert_emb.shape)
250
+ # exit()
251
+
252
+
253
+ out = out.last_hidden_state
254
+
255
+ out=out[:,:,-len_x:]
256
+ # out=self.proj_out(out)
257
+
258
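+ # Loss weighting: frames marked 2 in latent_masks (generation targets) get full weight,
+ # padding frames (0) get a small 0.01 weight, and in-context frames (1) are excluded.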
+ weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
259
+ # print("out.shape",out.shape)
260
+ # print("u.shape",u.shape)
261
+ loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
262
+ # print("hidden_proj.shape:",hidden_proj.shape)
263
+ # print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
264
+ loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
265
+ loss = loss_re + loss_cos * 0.5
266
+ # print("loss_cos:",loss_cos,loss_cos.device)
267
+ print("loss:",loss,loss.device)
268
+ # exit()
269
+ return loss, loss_re, loss_cos
270
+
271
+ class PromptCondAudioDiffusion(nn.Module):
272
+ def __init__(
273
+ self,
274
+ num_channels,
275
+ unet_model_name=None,
276
+ unet_model_config_path=None,
277
+ snr_gamma=None,
278
+ hubert_layer=None,
279
+ ssl_layer=None,
280
+ uncondition=True,
281
+ out_paint=False,
282
+ ):
283
+ super().__init__()
284
+
285
+ assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
286
+
287
+ self.unet_model_name = unet_model_name
288
+ self.unet_model_config_path = unet_model_config_path
289
+ self.snr_gamma = snr_gamma
290
+ self.uncondition = uncondition
291
+ self.num_channels = num_channels
292
+ self.hubert_layer = hubert_layer
293
+ self.ssl_layer = ssl_layer
294
+
295
+ # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
296
+ self.normfeat = Feature1DProcessor(dim=64)
297
+
298
+ self.sample_rate = 48000
299
+ self.num_samples_perseg = self.sample_rate * 20 // 1000
300
+ self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
301
+ self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
302
+ # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
303
+ # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
304
+ self.bestrq = load_model(
305
+ model_dir='path/to/our-MERT/mert_fairseq',
306
+ checkpoint_dir='checkpoint-120000.pt',
307
+ )
308
+ self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
309
+ self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
310
+ for v in self.bestrq.parameters():v.requires_grad = False
311
+ self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
312
+ # for v in self.rvq_bestrq_emb.parameters():
313
+ # print(v)
314
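+ # Freeze only the first RVQ codebook ('quantizers.0'); the remaining codebooks stay trainable.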
+ freeze_parameters='quantizers.0'
315
+ for name, param in self.rvq_bestrq_emb.named_parameters():
316
+ if freeze_parameters in name:
317
+ param.requires_grad = False
318
+ print("Freezing RVQ parameters:", name)
319
+ self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
320
+ for v in self.hubert.parameters():v.requires_grad = False
321
+ self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
322
+ # self.xvecmodel = XVECModel()
323
+ config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
324
+ unet = GPT2Model(config)
325
+ mlp = nn.Sequential(
326
+ nn.Linear(1200, 1024),
327
+ nn.SiLU(),
328
+ nn.Linear(1024, 1024),
329
+ nn.SiLU(),
330
+ nn.Linear(1024, 768)
331
+ )
332
+ self.set_from = "random"
333
+ self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
334
+ self.mask_emb = torch.nn.Embedding(3, 48)
335
+ print("Transformer initialized from pretrain.")
336
+ torch.cuda.empty_cache()
337
+ # self.unet.set_attn_processor(AttnProcessor2_0())
338
+ # self.unet.set_use_memory_efficient_attention_xformers(True)
339
+
340
+ # self.start_embedding = nn.Parameter(torch.randn(1,1024))
341
+ # self.end_embedding = nn.Parameter(torch.randn(1,1024))
342
+
343
+ def compute_snr(self, timesteps):
344
+ """
345
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
346
+ """
347
+ alphas_cumprod = self.noise_scheduler.alphas_cumprod
348
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
349
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
350
+
351
+ # Expand the tensors.
352
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
353
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
354
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
355
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
356
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
357
+
358
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
359
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
360
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
361
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
362
+
363
+ # Compute SNR.
364
+ snr = (alpha / sigma) ** 2
365
+ return snr
366
+
367
+ def preprocess_audio(self, input_audios, threshold=0.9):
368
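+ # Peak-normalise only clips whose absolute peak exceeds `threshold`, scaling them down
+ # so the peak lands exactly at the threshold; quieter clips are left untouched.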
+ assert len(input_audios.shape) == 2, input_audios.shape
369
+ norm_value = torch.ones_like(input_audios[:,0])
370
+ max_volume = input_audios.abs().max(dim=-1)[0]
371
+ norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
372
+ return input_audios/norm_value.unsqueeze(-1)
373
+
374
+ def extract_wav2vec_embeds(self, input_audios,output_len):
375
+ wav2vec_stride = 2
376
+
377
+ wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
378
+ # print(wav2vec_embeds)
379
+ # print("audio.shape:",input_audios.shape)
380
+ wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer]
381
+ # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape)
382
+ wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
383
+ return wav2vec_embeds_last
384
+
385
+ def extract_mert_embeds(self, input_audios):
386
+ prompt_stride = 3
387
+ inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
388
+ input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
389
+ prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
390
+ mert_emb= prompt_embeds[-1]
391
+ mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1)
392
+
393
+ return mert_emb
394
+
395
+ def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer):
396
+ self.bestrq.eval()
397
+ # print("audio shape:",input_audio_0.shape)
398
+ input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
399
+ # print("input_wav_mean.shape:",input_wav_mean.shape)
400
+ # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device)
401
+ input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
402
+ layer_results = input_wav_mean['layer_results']
403
+ # print("layer_results.shape:",layer_results[layer].shape)
404
+ bestrq_emb = layer_results[layer]
405
+ bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
406
+ #[b,t,1024] t=t/960
407
+ #35.84s->batch,896,1024
408
+ return bestrq_emb
409
+
410
+
411
+ def extract_spk_embeds(self, input_audios):
412
+ spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
413
+ spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
414
+ return spk_embeds
415
+
416
+ def extract_lyric_feats(self, lyric):
417
+ with torch.no_grad():
418
+ try:
419
+ text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
420
+ except:
421
+ text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
422
+ text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
423
+ text_mask = text_mask.to(self.device)
424
+ text_encoder_hidden_states, text_mask, text_prompt_embeds = \
425
+ pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
426
+ text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
427
+ return text_encoder_hidden_states, text_mask
428
+
429
+ def extract_energy_bar(self, input_audios):
430
+ if(input_audios.shape[-1] % self.num_samples_perseg > 0):
431
+ energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
432
+ else:
433
+ energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
434
+ energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
435
+ energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
436
+ energy_embedding = self.energy_embedding(energy_bar)
437
+ energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
438
+ return energy_embedding
439
+
440
+ def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \
441
+ additional_feats = ['spk', 'lyric'], \
442
+ train_rvq=True, train_ssl=False,layer=5):
443
+ if not hasattr(self,"device"):
444
+ self.device = input_audios.device
445
+ if not hasattr(self,"dtype"):
446
+ self.dtype = input_audios.dtype
447
+ device = self.device
448
+ input_audio_0 = input_audios[:,0,:]
449
+ input_audio_1 = input_audios[:,1,:]
450
+ input_audio_0 = self.preprocess_audio(input_audio_0)
451
+ input_audio_1 = self.preprocess_audio(input_audio_1)
452
+ input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0
453
+ # energy_embedding = self.extract_energy_bar(input_audios)
454
+ # print("energy_embedding.shape:",energy_embedding.shape)
455
+ # with autocast(enabled=False):
456
+ if(train_ssl):
457
+ self.wav2vec.train()
458
+ wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
459
+ self.clap_embd_extractor.train()
460
+ prompt_embeds = self.extract_mert_embeds(input_audios)
461
+ if('spk' in additional_feats):
462
+ self.xvecmodel.train()
463
+ spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
464
+ else:
465
+ with torch.no_grad():
466
+ with autocast(enabled=False):
467
+ bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
468
+ # mert_emb = self.extract_mert_embeds(input_audios_mert)
469
+
470
+ wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2])
471
+
472
+ bestrq_emb = bestrq_emb.detach()
473
+ if('lyric' in additional_feats):
474
+ text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
475
+ else:
476
+ text_encoder_hidden_states, text_mask = None, None
477
+
478
+
479
+ if(train_rvq):
480
+ random_num=random.random()
481
+ if(random_num<0.6):
482
+ rvq_layer = 1
483
+ elif(random_num<0.8):
484
+ rvq_layer = 2
485
+ else:
486
+ rvq_layer = 4
487
+ quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t
488
+ else:
489
+ bestrq_emb = bestrq_emb.float()
490
+ self.rvq_bestrq_emb.eval()
491
+ # with autocast(enabled=False):
492
+ quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
493
+ commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
494
+ codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
495
+ quantized_bestrq_emb = quantized_bestrq_emb.detach()
496
+
497
+ commitment_loss = commitment_loss_bestrq_emb
498
+ codebook_loss = codebook_loss_bestrq_emb
499
+
500
+
501
+ alpha=1
502
+ quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
503
+
504
+ # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
505
+ # print("latent_masks.shape:",latent_masks.shape)
506
+ # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
507
+
508
+
509
+
510
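+ # latent_masks convention (inferred from its usage in this file): 0 = padding, 1 = in-context
+ # prompt frames that are kept fixed, 2 = frames the model learns to generate.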
+ scenario = np.random.choice(['start_seg', 'other_seg'])
511
+ if(scenario == 'other_seg'):
512
+ for binx in range(input_audios.shape[0]):
513
+ # latent_masks[binx,0:64] = 1
514
+ latent_masks[binx,0:random.randint(64,128)] = 1
515
+ quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
516
+ # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
517
+ # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape)
518
+ # print("latent_masks.shape:",latent_masks.shape)
519
+ quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
520
+ + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
521
+
522
+
523
+
524
+
525
+ if self.uncondition:
526
+ mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
527
+ if len(mask_indices) > 0:
528
+ quantized_bestrq_emb[mask_indices] = 0
529
+ # print("latents.shape:",latents.shape)
530
+ latents = latents.permute(0,2,1).contiguous()
531
+ latents = self.normfeat.project_sample(latents)
532
+ latents = latents.permute(0,2,1).contiguous()
533
+ incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
534
+ attention_mask=(latent_masks > 0.5)
535
+ B, L = attention_mask.size()
536
+ attention_mask = attention_mask.view(B, 1, L)
537
+ attention_mask = attention_mask * attention_mask.transpose(-1, -2)
538
+ attention_mask = attention_mask.unsqueeze(1)
539
+ # print("incontext_latents.shape:",incontext_latents.shape)
540
+ # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
541
+ latent_mask_input = self.mask_emb(latent_masks)
542
+ #64+48+64+1024
543
+ loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
544
+ return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
545
+
546
+ def init_device_dtype(self, device, dtype):
547
+ self.device = device
548
+ self.dtype = dtype
549
+
550
+ @torch.no_grad()
551
+ def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1):
552
+ input_audio_0 = input_audios[[0],:]
553
+ input_audio_1 = input_audios[[1],:]
554
+ input_audio_0 = self.preprocess_audio(input_audio_0)
555
+ input_audio_1 = self.preprocess_audio(input_audio_1)
556
+
557
+ self.bestrq.eval()
558
+
559
+ # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
560
+ # bestrq_middle = bestrq_middle.detach()
561
+ # bestrq_last = bestrq_last.detach()
562
+ bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
563
+ bestrq_emb = bestrq_emb.detach()
564
+
565
+ # self.rvq_bestrq_middle.eval()
566
+ # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
567
+ # self.rvq_bestrq_last.eval()
568
+ # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
569
+
570
+ self.rvq_bestrq_emb.eval()
571
+ quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
572
+ codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
573
+ # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
574
+ # exit()
575
+
576
+
577
+ if('spk' in additional_feats):
578
+ self.xvecmodel.eval()
579
+ spk_embeds = self.extract_spk_embeds(input_audios)
580
+ else:
581
+ spk_embeds = None
582
+
583
+ # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
584
+ # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
585
+ # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
586
+ return [codes_bestrq_emb], [bestrq_emb], spk_embeds
587
+ # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
588
+
589
+ @torch.no_grad()
590
+ def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1):
591
+ input_audio_0 = input_audios[:,0,:]
592
+ input_audio_1 = input_audios[:,1,:]
593
+ input_audio_0 = self.preprocess_audio(input_audio_0)
594
+ input_audio_1 = self.preprocess_audio(input_audio_1)
595
+
596
+ self.bestrq.eval()
597
+
598
+ # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
599
+ # bestrq_middle = bestrq_middle.detach()
600
+ # bestrq_last = bestrq_last.detach()
601
+ bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
602
+ bestrq_emb = bestrq_emb.detach()
603
+
604
+ # self.rvq_bestrq_middle.eval()
605
+ # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
606
+ # self.rvq_bestrq_last.eval()
607
+ # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
608
+
609
+ self.rvq_bestrq_emb.eval()
610
+ quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
611
+ # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
612
+ codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
613
+ # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
614
+ # exit()
615
+
616
+
617
+ if('spk' in additional_feats):
618
+ self.xvecmodel.eval()
619
+ spk_embeds = self.extract_spk_embeds(input_audios)
620
+ else:
621
+ spk_embeds = None
622
+
623
+ # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
624
+ # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
625
+ # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
626
+ return [codes_bestrq_emb], [bestrq_emb], spk_embeds
627
+ # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
628
+
629
+ @torch.no_grad()
630
+ def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250):
631
+ input_audio_0 = input_audios[:,0,:]
632
+ input_audio_1 = input_audios[:,1,:]
633
+ input_audio_0 = self.preprocess_audio(input_audio_0)
634
+ input_audio_1 = self.preprocess_audio(input_audio_1)
635
+
636
+ self.bestrq.eval()
637
+
638
+ # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
639
+ # bestrq_middle = bestrq_middle.detach()
640
+ # bestrq_last = bestrq_last.detach()
641
+ bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
642
+ bestrq_emb = bestrq_emb.detach()
643
+
644
+ # self.rvq_bestrq_middle.eval()
645
+ # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
646
+ # self.rvq_bestrq_last.eval()
647
+ # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
648
+
649
+ self.rvq_bestrq_emb.eval()
650
+ bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds)
651
+ quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
652
+ # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
653
+ codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
654
+ # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
655
+ # exit()
656
+
657
+
658
+ if('spk' in additional_feats):
659
+ self.xvecmodel.eval()
660
+ spk_embeds = self.extract_spk_embeds(input_audios)
661
+ else:
662
+ spk_embeds = None
663
+
664
+ # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
665
+ # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
666
+ # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
667
+ return [codes_bestrq_emb], [bestrq_emb], spk_embeds
668
+ # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
669
+
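The ds argument of fetch_codes_batch_ds above controls a plain average-pooling downsample of the SSL features before quantization. A minimal, self-contained sketch of that step; the batch/dim/time sizes below are made up for illustration, not taken from a real checkpoint:

    import torch
    import torch.nn.functional as F

    ds = 250                              # pooling factor, as in fetch_codes_batch_ds
    feats = torch.randn(2, 1024, 1500)    # stand-in for bestrq_emb: (batch, dim, time)
    pooled = F.avg_pool1d(feats, kernel_size=ds, stride=ds)
    print(pooled.shape)                   # torch.Size([2, 1024, 6]) -> time axis shrinks by ds
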
670
+ @torch.no_grad()
671
+ def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127,
672
+ guidance_scale=2, num_steps=20,
673
+ disable_progress=True, scenario='start_seg'):
674
+ classifier_free_guidance = guidance_scale > 1.0
675
+ device = self.device
676
+ dtype = self.dtype
677
+ # codes_bestrq_middle, codes_bestrq_last = codes
678
+ codes_bestrq_emb = codes[0]
679
+
680
+
681
+ batch_size = codes_bestrq_emb.shape[0]
682
+
683
+
684
+ quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
685
+ # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
686
+ quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
687
+ print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
688
+ # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
689
+
690
+
691
+
692
+
693
+ if('spk' in additional_feats):
694
+ spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
695
+
696
+ num_frames = quantized_bestrq_emb.shape[1]
697
+
698
+ num_channels_latents = self.num_channels
699
+ shape = (batch_size, num_frames, 64)
700
+ latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
701
+
702
+
703
+
704
+ latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
705
+ latent_masks[:,0:latent_length] = 2
706
+ if(scenario=='other_seg'):
707
+ latent_masks[:,0:incontext_length] = 1
708
+
709
+
710
+
711
+ quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
712
+ + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
713
+ true_latents = true_latents.permute(0,2,1).contiguous()
714
+ true_latents = self.normfeat.project_sample(true_latents)
715
+ true_latents = true_latents.permute(0,2,1).contiguous()
716
+ incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
717
+ incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
718
+
719
+
720
+ attention_mask=(latent_masks > 0.5)
721
+ B, L = attention_mask.size()
722
+ attention_mask = attention_mask.view(B, 1, L)
723
+ attention_mask = attention_mask * attention_mask.transpose(-1, -2)
724
+ attention_mask = attention_mask.unsqueeze(1)
725
+ latent_mask_input = self.mask_emb(latent_masks)
726
+
727
+ if('spk' in additional_feats):
728
+ # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
729
+ additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1)
730
+ else:
731
+ # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
732
+ additional_model_input = torch.cat([quantized_bestrq_emb],1)
733
+
734
+ temperature = 1.0
735
+ t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
736
+ latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
737
+
738
+ latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
739
+ latents = latents.permute(0,2,1).contiguous()
740
+ latents = self.normfeat.return_sample(latents)
741
+ # latents = latents.permute(0,2,1).contiguous()
742
+ return latents
743
+
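For readers skimming inference_codes above: latent_masks uses 0 for padding, 2 for frames to be generated, and 1 for in-context prompt frames (scenario 'other_seg'), and the attention mask is the pairwise product of the per-frame validity mask. A small stand-alone sketch with invented lengths:

    import torch

    B, T = 1, 10                                  # illustrative batch size and frame count
    latent_length, incontext_length = 8, 3

    latent_masks = torch.zeros(B, T, dtype=torch.int64)
    latent_masks[:, 0:latent_length] = 2          # frames the flow-matching solver fills in
    latent_masks[:, 0:incontext_length] = 1       # prompt frames kept fixed in 'other_seg'

    valid = latent_masks > 0.5                    # per-frame validity
    attention_mask = valid.view(B, 1, T)
    attention_mask = attention_mask * attention_mask.transpose(-1, -2)   # (B, T, T)
    attention_mask = attention_mask.unsqueeze(1)                         # (B, 1, T, T)
    print(latent_masks[0].tolist())               # [1, 1, 1, 2, 2, 2, 2, 2, 0, 0]
    print(attention_mask.shape)                   # torch.Size([1, 1, 10, 10])
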
744
+ @torch.no_grad()
745
+ def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
746
+ disable_progress=True,layer=5,scenario='start_seg',rvq_num=1):
747
+ codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num)
748
+
749
+ latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
750
+ guidance_scale=guidance_scale, num_steps=num_steps, \
751
+ disable_progress=disable_progress,scenario=scenario)
752
+ return latents
753
+
754
+ @torch.no_grad()
755
+ def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
756
+ disable_progress=True,layer=5,scenario='start_seg'):
757
+ codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
758
+ import time
759
+ start = time.time()
760
+ latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
761
+ guidance_scale=guidance_scale, num_steps=num_steps, \
762
+ disable_progress=disable_progress,scenario=scenario)
763
+ return latents,time.time()-start
764
+
765
+ def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
766
+ divisor = 4
767
+ shape = (batch_size, num_channels_latents, num_frames, 32)
768
+ if(num_frames%divisor>0):
769
+ num_frames = round(num_frames/float(divisor))*divisor
770
+ shape = (batch_size, num_channels_latents, num_frames, 32)
771
+ latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
772
+ return latents
773
+
774
+
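solve_euler is called above with guidance_scale=2 by default, which doubles the batch into an unconditional and a conditional half and then recombines the two velocity predictions. The recombination step on its own, with random stand-in tensors of invented shape:

    import torch

    guidance_scale = 2.0
    dphi_dt = torch.randn(4, 750, 64)               # stacked [uncond; cond] predictions for a batch of 2
    dphi_dt_uncond, dphi_dt_cond = dphi_dt.chunk(2, 0)
    guided = dphi_dt_uncond + guidance_scale * (dphi_dt_cond - dphi_dt_uncond)
    print(guided.shape)                             # torch.Size([2, 750, 64])
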
codeclm/tokenizer/Flow1dVAE/model_septoken.py CHANGED
@@ -1,670 +1,670 @@
1
- import yaml
2
- import random
3
- import inspect
4
- import numpy as np
5
- from tqdm import tqdm
6
- import typing as tp
7
- from abc import ABC
8
-
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
- import torchaudio
13
-
14
- from einops import repeat
15
- from tools.torch_tools import wav_to_fbank
16
-
17
- from diffusers.utils.torch_utils import randn_tensor
18
- from transformers import HubertModel
19
- from libs.rvq.descript_quantize3 import ResidualVectorQuantize
20
-
21
- from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
22
- from models_gpt.models.gpt2_config import GPT2Config
23
-
24
- from torch.cuda.amp import autocast
25
- from our_MERT_BESTRQ.test import load_model
26
-
27
- class HubertModelWithFinalProj(HubertModel):
28
- def __init__(self, config):
29
- super().__init__(config)
30
-
31
- # The final projection layer is only used for backward compatibility.
32
- # Following https://github.com/auspicious3000/contentvec/issues/6
33
- # Removing this layer is necessary to achieve the desired outcome.
34
- print("hidden_size:",config.hidden_size)
35
- print("classifier_proj_size:",config.classifier_proj_size)
36
- self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
37
-
38
-
39
- class SampleProcessor(torch.nn.Module):
- def project_sample(self, x: torch.Tensor):
- """Project the original sample to the 'space' where the diffusion will happen."""
- return x
-
- def return_sample(self, z: torch.Tensor):
- """Project back from diffusion space to the actual sample space."""
- return z
44
-
45
- class Feature1DProcessor(SampleProcessor):
46
- def __init__(self, dim: int = 100, power_std = 1., \
47
- num_samples: int = 100_000, cal_num_frames: int = 600):
48
- super().__init__()
49
-
50
- self.num_samples = num_samples
51
- self.dim = dim
52
- self.power_std = power_std
53
- self.cal_num_frames = cal_num_frames
54
- self.register_buffer('counts', torch.zeros(1))
55
- self.register_buffer('sum_x', torch.zeros(dim))
56
- self.register_buffer('sum_x2', torch.zeros(dim))
57
- self.register_buffer('sum_target_x2', torch.zeros(dim))
58
- self.counts: torch.Tensor
59
- self.sum_x: torch.Tensor
60
- self.sum_x2: torch.Tensor
61
-
62
- @property
63
- def mean(self):
64
- mean = self.sum_x / self.counts
65
- if(self.counts < 10):
66
- mean = torch.zeros_like(mean)
67
- return mean
68
-
69
- @property
70
- def std(self):
71
- std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
72
- if(self.counts < 10):
73
- std = torch.ones_like(std)
74
- return std
75
-
76
- @property
77
- def target_std(self):
78
- return 1
79
-
80
- def project_sample(self, x: torch.Tensor):
81
- assert x.dim() == 3
82
- if self.counts.item() < self.num_samples:
83
- self.counts += len(x)
84
- self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
85
- self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
86
- rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
87
- x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
88
- return x
89
-
90
- def return_sample(self, x: torch.Tensor):
91
- assert x.dim() == 3
92
- rescale = (self.std / self.target_std) ** self.power_std
93
- x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
94
- return x
95
-
96
- def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
97
- if(prior_text_encoder_hidden_states.shape[1]<len_size):
98
- prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
99
- torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
100
- prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
101
- dtype=prior_text_encoder_hidden_states.dtype)],1)
102
- prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
103
- else:
104
- prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
105
- prior_text_mask = prior_text_mask[:,0:len_size]
106
- prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
107
- return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
108
-
109
- class BASECFM(torch.nn.Module, ABC):
110
- def __init__(
111
- self,
112
- estimator,
113
- mlp
114
- ):
115
- super().__init__()
116
- self.sigma_min = 1e-4
117
-
118
- self.estimator = estimator
119
- self.mlp = mlp
120
-
121
- @torch.inference_mode()
122
- def forward(self, mu, n_timesteps, temperature=1.0):
123
- """Forward diffusion
124
-
125
- Args:
126
- mu (torch.Tensor): output of encoder
127
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
128
- n_timesteps (int): number of diffusion steps
129
- temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
130
-
131
- Returns:
132
- sample: generated mel-spectrogram
133
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
134
- """
135
- z = torch.randn_like(mu) * temperature
136
- t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
137
- return self.solve_euler(z, t_span=t_span)
138
-
139
- def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
140
- """
141
- Fixed Euler solver for ODEs.
142
- Args:
143
- x (torch.Tensor): random noise
144
- t_span (torch.Tensor): n_timesteps interpolated
145
- shape: (n_timesteps + 1,)
146
- mu (torch.Tensor): output of encoder
147
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
148
- """
149
- t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
150
- noise = x.clone()
151
-
152
- # I am storing this because I can later plot it by putting a debugger here and saving it to a file
153
- # Or in future might add like a return_all_steps flag
154
- sol = []
155
-
156
- for step in tqdm(range(1, len(t_span))):
157
- x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
158
- if(guidance_scale > 1.0):
159
-
160
- model_input = torch.cat([ \
161
- torch.cat([latent_mask_input, latent_mask_input], 0), \
162
- torch.cat([incontext_x, incontext_x], 0), \
163
- torch.cat([torch.zeros_like(mu), mu], 0), \
164
- torch.cat([x, x], 0), \
165
- ], 2)
166
- timestep=t.unsqueeze(-1).repeat(2)
167
-
168
- dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
169
- dphi_dt_uncond, dphi_dt_cond = dphi_dt.chunk(2,0)
170
- dphi_dt = dphi_dt_uncond + guidance_scale * (dphi_dt_cond - dphi_dt_uncond)
171
- else:
172
- model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
173
- timestep=t.unsqueeze(-1)
174
- dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
175
-
176
- dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
177
- x = x + dt * dphi_dt
178
- t = t + dt
179
- sol.append(x)
180
- if step < len(t_span) - 1:
181
- dt = t_span[step + 1] - t
182
-
183
- return sol[-1]
184
-
185
- def projection_loss(self,hidden_proj, bestrq_emb):
186
- bsz = hidden_proj.shape[0]
187
-
188
- hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
189
- bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
190
-
191
- proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
192
- proj_loss = 1+proj_loss.mean()
193
-
194
- return proj_loss
195
-
196
- def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
197
- """Computes diffusion loss
198
-
199
- Args:
200
- x1 (torch.Tensor): Target
201
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
202
- mu (torch.Tensor): output of encoder
203
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
204
-
205
- Returns:
206
- loss: conditional flow matching loss
207
- y: conditional flow
208
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
209
- """
210
- b = mu[0].shape[0]
211
- len_x = x1.shape[2]
212
- # random timestep
213
- if(validation_mode):
214
- t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
215
- else:
216
- t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
217
- # sample noise p(x_0)
218
- z = torch.randn_like(x1)
219
-
220
- y = (1 - (1 - self.sigma_min) * t) * z + t * x1
221
- u = x1 - (1 - self.sigma_min) * z
222
- model_input = torch.cat([*mu,y], 2)
223
- t=t.squeeze(-1).squeeze(-1)
224
- out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
225
- hidden_layer_7 = out.hidden_states[7]
226
- hidden_proj = self.mlp(hidden_layer_7)
227
- out = out.last_hidden_state
228
- out=out[:,:,-len_x:]
229
-
230
- weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
231
- loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
232
- loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
233
- loss = loss_re + loss_cos * 0.5
234
- return loss, loss_re, loss_cos
235
-
236
- class PromptCondAudioDiffusion(nn.Module):
237
- def __init__(
238
- self,
239
- num_channels,
240
- unet_model_name=None,
241
- unet_model_config_path=None,
242
- snr_gamma=None,
243
- uncondition=True,
244
- out_paint=False,
245
- ):
246
- super().__init__()
247
-
248
- assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
249
-
250
- self.unet_model_name = unet_model_name
251
- self.unet_model_config_path = unet_model_config_path
252
- self.snr_gamma = snr_gamma
253
- self.uncondition = uncondition
254
- self.num_channels = num_channels
255
-
256
- # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
257
- self.normfeat = Feature1DProcessor(dim=64)
258
-
259
- self.sample_rate = 48000
260
- self.num_samples_perseg = self.sample_rate * 20 // 1000
261
- self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
262
- self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
263
- # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
264
- # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
265
- self.bestrq = load_model(
266
- model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq',
267
- checkpoint_dir='ckpt/encode-s12k.pt',
268
- )
269
- self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
270
- self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
271
- for v in self.bestrq.parameters():v.requires_grad = False
272
- self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
273
- self.rvq_bestrq_bgm_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
274
- self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
275
- for v in self.hubert.parameters():v.requires_grad = False
276
- self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
277
- # self.xvecmodel = XVECModel()
278
- config = GPT2Config(n_positions=1000,n_layer=16,n_head=20,n_embd=2200,n_inner=4400)
279
- unet = GPT2Model(config)
280
- mlp = nn.Sequential(
281
- nn.Linear(2200, 1024),
282
- nn.SiLU(),
283
- nn.Linear(1024, 1024),
284
- nn.SiLU(),
285
- nn.Linear(1024, 768)
286
- )
287
- self.set_from = "random"
288
- self.cfm_wrapper = BASECFM(unet, mlp)
289
- self.mask_emb = torch.nn.Embedding(3, 24)
290
- print("Transformer initialized from pretrain.")
291
- torch.cuda.empty_cache()
292
-
293
- def compute_snr(self, timesteps):
294
- """
295
- Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
296
- """
297
- alphas_cumprod = self.noise_scheduler.alphas_cumprod
298
- sqrt_alphas_cumprod = alphas_cumprod**0.5
299
- sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
300
-
301
- # Expand the tensors.
302
- # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
303
- sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
304
- while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
305
- sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
306
- alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
307
-
308
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
309
- while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
310
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
311
- sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
312
-
313
- # Compute SNR.
314
- snr = (alpha / sigma) ** 2
315
- return snr
316
-
317
- def preprocess_audio(self, input_audios, threshold=0.9):
318
- assert len(input_audios.shape) == 2, input_audios.shape
319
- norm_value = torch.ones_like(input_audios[:,0])
320
- max_volume = input_audios.abs().max(dim=-1)[0]
321
- norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
322
- return input_audios/norm_value.unsqueeze(-1)
323
-
324
- def extract_wav2vec_embeds(self, input_audios,output_len):
325
- wav2vec_stride = 2
326
-
327
- wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
328
- wav2vec_embeds_last=wav2vec_embeds[-1]
329
- wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
330
- return wav2vec_embeds_last
331
-
332
- def extract_mert_embeds(self, input_audios):
333
- prompt_stride = 3
334
- inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
335
- input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
336
- prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
337
- mert_emb= prompt_embeds[-1]
338
- mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=375, mode='linear', align_corners=False).permute(0, 2, 1)
339
-
340
- return mert_emb
341
-
342
- def extract_bestrq_embeds(self, input_audio_vocal_0,input_audio_vocal_1,layer):
343
- input_wav_mean = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0
344
- input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
345
- layer_results = input_wav_mean['layer_results']
346
- bestrq_emb = layer_results[layer]
347
- bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
348
- return bestrq_emb
349
-
350
-
351
- def extract_spk_embeds(self, input_audios):
352
- spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
353
- spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
354
- return spk_embeds
355
-
356
- def extract_lyric_feats(self, lyric):
357
- with torch.no_grad():
358
- try:
359
- text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
360
- except:
361
- text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
362
- text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
363
- text_mask = text_mask.to(self.device)
364
- text_encoder_hidden_states, text_mask, text_prompt_embeds = \
365
- pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
366
- text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
367
- return text_encoder_hidden_states, text_mask
368
-
369
- def extract_energy_bar(self, input_audios):
370
- if(input_audios.shape[-1] % self.num_samples_perseg > 0):
371
- energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
372
- else:
373
- energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
374
- energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
375
- energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
376
- energy_embedding = self.energy_embedding(energy_bar)
377
- energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
378
- return energy_embedding
379
-
380
- def forward(self, input_audios_vocal,input_audios_bgm, lyric, latents, latent_masks, validation_mode=False, \
381
- additional_feats = ['spk', 'lyric'], \
382
- train_rvq=True, train_ssl=False,layer_vocal=7,layer_bgm=7):
383
- if not hasattr(self,"device"):
384
- self.device = input_audios_vocal.device
385
- if not hasattr(self,"dtype"):
386
- self.dtype = input_audios_vocal.dtype
387
- device = self.device
388
- input_audio_vocal_0 = input_audios_vocal[:,0,:]
389
- input_audio_vocal_1 = input_audios_vocal[:,1,:]
390
- input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0)
391
- input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1)
392
- input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0
393
-
394
- input_audio_bgm_0 = input_audios_bgm[:,0,:]
395
- input_audio_bgm_1 = input_audios_bgm[:,1,:]
396
- input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0)
397
- input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1)
398
- input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0
399
-
400
- if(train_ssl):
401
- self.wav2vec.train()
402
- wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
403
- self.clap_embd_extractor.train()
404
- prompt_embeds = self.extract_mert_embeds(input_audios)
405
- if('spk' in additional_feats):
406
- self.xvecmodel.train()
407
- spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
408
- else:
409
- with torch.no_grad():
410
- with autocast(enabled=False):
411
- bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal)
412
- bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm)
413
- # mert_emb = self.extract_mert_embeds(input_audios_mert)
414
- output_len = bestrq_emb.shape[2]
415
- wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_vocal_wav2vec+input_audios_bgm_wav2vec,output_len)
416
-
417
-
418
- bestrq_emb = bestrq_emb.detach()
419
- bestrq_emb_bgm = bestrq_emb_bgm.detach()
420
-
421
- if('lyric' in additional_feats):
422
- text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
423
- else:
424
- text_encoder_hidden_states, text_mask = None, None
425
-
426
-
427
- if(train_rvq):
428
- quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
429
- quantized_bestrq_emb_bgm, _, _, commitment_loss_bestrq_emb_bgm, codebook_loss_bestrq_emb_bgm,_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t
430
- else:
431
- bestrq_emb = bestrq_emb.float()
432
- self.rvq_bestrq_emb.eval()
433
- # with autocast(enabled=False):
434
- quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
435
- commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
436
- codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
437
- quantized_bestrq_emb = quantized_bestrq_emb.detach()
438
-
439
- commitment_loss = commitment_loss_bestrq_emb+commitment_loss_bestrq_emb_bgm
440
- codebook_loss = codebook_loss_bestrq_emb+codebook_loss_bestrq_emb_bgm
441
-
442
-
443
- alpha=1
444
- quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
445
- quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm * alpha + bestrq_emb_bgm * (1-alpha)
446
-
447
-
448
-
449
-
450
- scenario = np.random.choice(['start_seg', 'other_seg'])
451
- if(scenario == 'other_seg'):
452
- for binx in range(input_audios_vocal.shape[0]):
453
- # latent_masks[binx,0:64] = 1
454
- latent_masks[binx,0:random.randint(64,128)] = 1
455
- quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
456
- quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm.permute(0,2,1).contiguous()
457
- quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
458
- + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
459
- quantized_bestrq_emb_bgm = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb_bgm \
460
- + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
461
-
462
-
463
-
464
-
465
- if self.uncondition:
466
- mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
467
- if len(mask_indices) > 0:
468
- quantized_bestrq_emb[mask_indices] = 0
469
- quantized_bestrq_emb_bgm[mask_indices] = 0
470
- latents = latents.permute(0,2,1).contiguous()
471
- latents = self.normfeat.project_sample(latents)
472
- latents = latents.permute(0,2,1).contiguous()
473
- incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
474
- attention_mask=(latent_masks > 0.5)
475
- B, L = attention_mask.size()
476
- attention_mask = attention_mask.view(B, 1, L)
477
- attention_mask = attention_mask * attention_mask.transpose(-1, -2)
478
- attention_mask = attention_mask.unsqueeze(1)
479
- latent_mask_input = self.mask_emb(latent_masks)
480
- loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb,quantized_bestrq_emb_bgm], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
481
- return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
482
-
483
- def init_device_dtype(self, device, dtype):
484
- self.device = device
485
- self.dtype = dtype
486
-
487
- @torch.no_grad()
488
- def fetch_codes(self, input_audios_vocal,input_audios_bgm, additional_feats,layer_vocal=7,layer_bgm=7):
489
- input_audio_vocal_0 = input_audios_vocal[[0],:]
490
- input_audio_vocal_1 = input_audios_vocal[[1],:]
491
- input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0)
492
- input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1)
493
- input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0
494
-
495
- input_audio_bgm_0 = input_audios_bgm[[0],:]
496
- input_audio_bgm_1 = input_audios_bgm[[1],:]
497
- input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0)
498
- input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1)
499
- input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0
500
-
501
- self.bestrq.eval()
502
-
503
- # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
504
- # bestrq_middle = bestrq_middle.detach()
505
- # bestrq_last = bestrq_last.detach()
506
- bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal)
507
- bestrq_emb = bestrq_emb.detach()
508
-
509
- bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm)
510
- bestrq_emb_bgm = bestrq_emb_bgm.detach()
511
-
512
-
513
-
514
- self.rvq_bestrq_emb.eval()
515
- quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
516
-
517
- self.rvq_bestrq_bgm_emb.eval()
518
- quantized_bestrq_emb_bgm, codes_bestrq_emb_bgm, *_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t
519
-
520
-
521
- if('spk' in additional_feats):
522
- self.xvecmodel.eval()
523
- spk_embeds = self.extract_spk_embeds(input_audios)
524
- else:
525
- spk_embeds = None
526
-
527
- # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
528
- # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
529
- # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
530
- return [codes_bestrq_emb,codes_bestrq_emb_bgm], [bestrq_emb,bestrq_emb_bgm], spk_embeds
531
- # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
532
-
533
- @torch.no_grad()
534
- def fetch_codes_batch(self, input_audios_vocal, input_audios_bgm, additional_feats,layer_vocal=7,layer_bgm=7):
535
- input_audio_vocal_0 = input_audios_vocal[:,0,:]
536
- input_audio_vocal_1 = input_audios_vocal[:,1,:]
537
- input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0)
538
- input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1)
539
- input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0
540
-
541
- input_audio_bgm_0 = input_audios_bgm[:,0,:]
542
- input_audio_bgm_1 = input_audios_bgm[:,1,:]
543
- input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0)
544
- input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1)
545
- input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0
546
-
547
- self.bestrq.eval()
548
-
549
- # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
550
- # bestrq_middle = bestrq_middle.detach()
551
- # bestrq_last = bestrq_last.detach()
552
- bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal)
553
- bestrq_emb = bestrq_emb.detach()
554
-
555
- bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm)
556
- bestrq_emb_bgm = bestrq_emb_bgm.detach()
557
-
558
-
559
-
560
- self.rvq_bestrq_emb.eval()
561
- quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
562
-
563
- self.rvq_bestrq_bgm_emb.eval()
564
- quantized_bestrq_emb_bgm, codes_bestrq_emb_bgm, *_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t
565
-
566
-
567
- if('spk' in additional_feats):
568
- self.xvecmodel.eval()
569
- spk_embeds = self.extract_spk_embeds(input_audios)
570
- else:
571
- spk_embeds = None
572
-
573
- # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
574
- # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
575
- # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
576
- return [codes_bestrq_emb,codes_bestrq_emb_bgm], [bestrq_emb,bestrq_emb_bgm], spk_embeds
577
- # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
578
-
579
-
580
- @torch.no_grad()
581
- def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats,incontext_length=127,
582
- guidance_scale=2, num_steps=20,
583
- disable_progress=True, scenario='start_seg'):
584
- classifier_free_guidance = guidance_scale > 1.0
585
- device = self.device
586
- dtype = self.dtype
587
- # codes_bestrq_middle, codes_bestrq_last = codes
588
- codes_bestrq_emb,codes_bestrq_emb_bgm = codes
589
-
590
-
591
- batch_size = codes_bestrq_emb.shape[0]
592
-
593
-
594
- quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
595
- quantized_bestrq_emb_bgm,_,_=self.rvq_bestrq_bgm_emb.from_codes(codes_bestrq_emb_bgm)
596
- quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
597
- quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm.permute(0,2,1).contiguous()
598
- if('spk' in additional_feats):
599
- spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
600
-
601
- num_frames = quantized_bestrq_emb.shape[1]
602
-
603
- num_channels_latents = self.num_channels
604
- shape = (batch_size, num_frames, 64)
605
- latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
606
-
607
-
608
-
609
- latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
610
- latent_masks[:,0:latent_length] = 2
611
- if(scenario=='other_seg'):
612
- latent_masks[:,0:incontext_length] = 1
613
-
614
-
615
-
616
- quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
617
- + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
618
- quantized_bestrq_emb_bgm = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb_bgm \
619
- + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
620
- true_latents = true_latents.permute(0,2,1).contiguous()
621
- true_latents = self.normfeat.project_sample(true_latents)
622
- true_latents = true_latents.permute(0,2,1).contiguous()
623
- incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
624
- incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
625
-
626
-
627
- attention_mask=(latent_masks > 0.5)
628
- B, L = attention_mask.size()
629
- attention_mask = attention_mask.view(B, 1, L)
630
- attention_mask = attention_mask * attention_mask.transpose(-1, -2)
631
- attention_mask = attention_mask.unsqueeze(1)
632
- latent_mask_input = self.mask_emb(latent_masks)
633
-
634
- if('spk' in additional_feats):
635
- # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
636
- additional_model_input = torch.cat([quantized_bestrq_emb,quantized_bestrq_emb_bgm, spk_embeds],2)
637
- else:
638
- # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
639
- additional_model_input = torch.cat([quantized_bestrq_emb,quantized_bestrq_emb_bgm],2)
640
-
641
- temperature = 1.0
642
- t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
643
- latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
644
-
645
- latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
646
- latents = latents.permute(0,2,1).contiguous()
647
- latents = self.normfeat.return_sample(latents)
648
- # latents = latents.permute(0,2,1).contiguous()
649
- return latents
650
-
651
- @torch.no_grad()
652
- def inference(self, input_audios_vocal,input_audios_bgm, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
653
- disable_progress=True,layer_vocal=7,layer_bgm=3,scenario='start_seg'):
654
- codes, embeds, spk_embeds = self.fetch_codes(input_audios_vocal,input_audios_bgm, additional_feats,layer_vocal,layer_bgm)
655
-
656
- latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
657
- guidance_scale=guidance_scale, num_steps=num_steps, \
658
- disable_progress=disable_progress,scenario=scenario)
659
- return latents
660
-
661
- def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
662
- divisor = 4
663
- shape = (batch_size, num_channels_latents, num_frames, 32)
664
- if(num_frames%divisor>0):
665
- num_frames = round(num_frames/float(divisor))*divisor
666
- shape = (batch_size, num_channels_latents, num_frames, 32)
667
- latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
668
- return latents
669
-
670
-
 
1
+ import yaml
2
+ import random
3
+ import inspect
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ import typing as tp
7
+ from abc import ABC
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torchaudio
13
+
14
+ from einops import repeat
15
+ from tools.torch_tools import wav_to_fbank
16
+
17
+ from diffusers.utils.torch_utils import randn_tensor
18
+ from transformers import HubertModel
19
+ from libs.rvq.descript_quantize3 import ResidualVectorQuantize
20
+
21
+ from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
22
+ from models_gpt.models.gpt2_config import GPT2Config
23
+
24
+ from torch.cuda.amp import autocast
25
+ from our_MERT_BESTRQ.test import load_model
26
+
27
+ class HubertModelWithFinalProj(HubertModel):
28
+ def __init__(self, config):
29
+ super().__init__(config)
30
+
31
+ # The final projection layer is only used for backward compatibility.
32
+ # Following https://github.com/auspicious3000/contentvec/issues/6
33
+ # Removing this layer is necessary to achieve the desired outcome.
34
+ print("hidden_size:",config.hidden_size)
35
+ print("classifier_proj_size:",config.classifier_proj_size)
36
+ self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
37
+
38
+
39
+ class SampleProcessor(torch.nn.Module):
+ def project_sample(self, x: torch.Tensor):
+ """Project the original sample to the 'space' where the diffusion will happen."""
+ return x
+
+ def return_sample(self, z: torch.Tensor):
+ """Project back from diffusion space to the actual sample space."""
+ return z
44
+
45
+ class Feature1DProcessor(SampleProcessor):
46
+ def __init__(self, dim: int = 100, power_std = 1., \
47
+ num_samples: int = 100_000, cal_num_frames: int = 600):
48
+ super().__init__()
49
+
50
+ self.num_samples = num_samples
51
+ self.dim = dim
52
+ self.power_std = power_std
53
+ self.cal_num_frames = cal_num_frames
54
+ self.register_buffer('counts', torch.zeros(1))
55
+ self.register_buffer('sum_x', torch.zeros(dim))
56
+ self.register_buffer('sum_x2', torch.zeros(dim))
57
+ self.register_buffer('sum_target_x2', torch.zeros(dim))
58
+ self.counts: torch.Tensor
59
+ self.sum_x: torch.Tensor
60
+ self.sum_x2: torch.Tensor
61
+
62
+ @property
63
+ def mean(self):
64
+ mean = self.sum_x / self.counts
65
+ if(self.counts < 10):
66
+ mean = torch.zeros_like(mean)
67
+ return mean
68
+
69
+ @property
70
+ def std(self):
71
+ std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
72
+ if(self.counts < 10):
73
+ std = torch.ones_like(std)
74
+ return std
75
+
76
+ @property
77
+ def target_std(self):
78
+ return 1
79
+
80
+ def project_sample(self, x: torch.Tensor):
81
+ assert x.dim() == 3
82
+ if self.counts.item() < self.num_samples:
83
+ self.counts += len(x)
84
+ self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
85
+ self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
86
+ rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
87
+ x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
88
+ return x
89
+
90
+ def return_sample(self, x: torch.Tensor):
91
+ assert x.dim() == 3
92
+ rescale = (self.std / self.target_std) ** self.power_std
93
+ x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
94
+ return x
95
+
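Feature1DProcessor whitens latents per channel with running statistics; return_sample inverts project_sample once enough samples have been counted. A stand-alone check of that round trip for the power_std=1, target_std=1 case used here, with the statistics computed directly rather than accumulated in buffers:

    import torch

    x = torch.randn(8, 64, 600) * 3.0 + 1.5                          # fake (batch, dim, time) latents
    mean = x.mean(dim=(0, 2))
    std = x.std(dim=(0, 2))

    projected = (x - mean.view(1, -1, 1)) / std.view(1, -1, 1)       # project_sample
    restored = projected * std.view(1, -1, 1) + mean.view(1, -1, 1)  # return_sample
    print(torch.allclose(restored, x, atol=1e-5))                    # True: the transforms invert each other
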
96
+ def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
97
+ if(prior_text_encoder_hidden_states.shape[1]<len_size):
98
+ prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
99
+ torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
100
+ prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
101
+ dtype=prior_text_encoder_hidden_states.dtype)],1)
102
+ prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
103
+ else:
104
+ prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
105
+ prior_text_mask = prior_text_mask[:,0:len_size]
106
+ prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
107
+ return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
108
+
109
+ class BASECFM(torch.nn.Module, ABC):
110
+ def __init__(
111
+ self,
112
+ estimator,
113
+ mlp
114
+ ):
115
+ super().__init__()
116
+ self.sigma_min = 1e-4
117
+
118
+ self.estimator = estimator
119
+ self.mlp = mlp
120
+
121
+ @torch.inference_mode()
122
+ def forward(self, mu, n_timesteps, temperature=1.0):
123
+ """Forward diffusion
124
+
125
+ Args:
126
+ mu (torch.Tensor): output of encoder
127
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
128
+ n_timesteps (int): number of diffusion steps
129
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
130
+
131
+ Returns:
132
+ sample: generated mel-spectrogram
133
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
134
+ """
135
+ z = torch.randn_like(mu) * temperature
136
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
137
+ return self.solve_euler(z, t_span=t_span)
138
+
139
+ def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
140
+ """
141
+ Fixed Euler solver for ODEs.
142
+ Args:
143
+ x (torch.Tensor): random noise
144
+ t_span (torch.Tensor): n_timesteps interpolated
145
+ shape: (n_timesteps + 1,)
146
+ mu (torch.Tensor): output of encoder
147
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
148
+ """
149
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
150
+ noise = x.clone()
151
+
152
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
153
+ # Or in future might add like a return_all_steps flag
154
+ sol = []
155
+
156
+ for step in tqdm(range(1, len(t_span))):
157
+ x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
158
+ if(guidance_scale > 1.0):
159
+
160
+ model_input = torch.cat([ \
161
+ torch.cat([latent_mask_input, latent_mask_input], 0), \
162
+ torch.cat([incontext_x, incontext_x], 0), \
163
+ torch.cat([torch.zeros_like(mu), mu], 0), \
164
+ torch.cat([x, x], 0), \
165
+ ], 2)
166
+ timestep=t.unsqueeze(-1).repeat(2)
167
+
168
+ dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
169
+ dphi_dt_uncond, dphi_dt_cond = dphi_dt.chunk(2,0)
170
+ dphi_dt = dphi_dt_uncond + guidance_scale * (dphi_dt_cond - dphi_dt_uncond)
171
+ else:
172
+ model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
173
+ timestep=t.unsqueeze(-1)
174
+ dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
175
+
176
+ dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
177
+ x = x + dt * dphi_dt
178
+ t = t + dt
179
+ sol.append(x)
180
+ if step < len(t_span) - 1:
181
+ dt = t_span[step + 1] - t
182
+
183
+ return sol[-1]
184
+
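Stripped of the masking and guidance bookkeeping, solve_euler is a fixed-step Euler integrator: x is updated by dt * v(x, t) over a uniform t_span. A toy check on a field with a known solution (dx/dt = -x, so x(1) = x(0) * e^-1), reusing the same step and dt update pattern:

    import math
    import torch

    def velocity(x, t):
        return -x                          # toy vector field with a known exact solution

    x = torch.ones(3)
    t_span = torch.linspace(0, 1, 21)      # 20 Euler steps, like num_steps=20
    t, dt = t_span[0], t_span[1] - t_span[0]
    for step in range(1, len(t_span)):
        x = x + dt * velocity(x, t)
        t = t + dt
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t

    print(x[0].item(), math.exp(-1))       # ~0.3585 vs 0.3679; the gap shrinks with more steps
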
185
+ def projection_loss(self,hidden_proj, bestrq_emb):
186
+ bsz = hidden_proj.shape[0]
187
+
188
+ hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
189
+ bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
190
+
191
+ proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
192
+ proj_loss = 1+proj_loss.mean()
193
+
194
+ return proj_loss
195
+
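projection_loss is 1 minus the mean cosine similarity between the projected hidden states and the SSL embeddings, so it ranges from 0 (aligned) to 2 (opposed). A stand-alone check with random tensors of invented shape:

    import torch
    import torch.nn.functional as F

    a = torch.randn(4, 100, 768)
    cos_same = (F.normalize(a, dim=-1) * F.normalize(a, dim=-1)).sum(dim=-1)
    cos_opp = (F.normalize(a, dim=-1) * F.normalize(-a, dim=-1)).sum(dim=-1)
    print(round((1 - cos_same.mean()).item(), 4))    # ~0.0 for identical embeddings
    print(round((1 - cos_opp.mean()).item(), 4))     # ~2.0 for opposed embeddings
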
196
+ def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
197
+ """Computes diffusion loss
198
+
199
+ Args:
200
+ x1 (torch.Tensor): Target
201
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
202
+ mu (torch.Tensor): output of encoder
203
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
204
+
205
+ Returns:
206
+ loss: conditional flow matching loss
207
+ y: conditional flow
208
+ shape: (batch_size, n_channels, mel_timesteps, n_feats)
209
+ """
210
+ b = mu[0].shape[0]
211
+ len_x = x1.shape[2]
212
+ # random timestep
213
+ if(validation_mode):
214
+ t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
215
+ else:
216
+ t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
217
+ # sample noise p(x_0)
218
+ z = torch.randn_like(x1)
219
+
220
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
221
+ u = x1 - (1 - self.sigma_min) * z
222
+ model_input = torch.cat([*mu,y], 2)
223
+ t=t.squeeze(-1).squeeze(-1)
224
+ out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
225
+ hidden_layer_7 = out.hidden_states[7]
226
+ hidden_proj = self.mlp(hidden_layer_7)
227
+ out = out.last_hidden_state
228
+ out=out[:,:,-len_x:]
229
+
230
+ weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
231
+ loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
232
+ loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
233
+ loss = loss_re + loss_cos * 0.5
234
+ return loss, loss_re, loss_cos
235
+
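compute_loss regresses the estimator onto the standard conditional flow matching target: y interpolates between noise z and data x1, and u is exactly the time derivative of that interpolant. A quick numerical confirmation with made-up tensor sizes and the same sigma_min:

    import torch

    sigma_min = 1e-4
    x1 = torch.randn(2, 16, 100)                     # fake target latents
    z = torch.randn_like(x1)                         # noise sample

    def interp(t):
        return (1 - (1 - sigma_min) * t) * z + t * x1

    u = x1 - (1 - sigma_min) * z                     # regression target used above
    t, eps = 0.3, 1e-3
    numeric = (interp(t + eps) - interp(t - eps)) / (2 * eps)
    print(torch.allclose(numeric, u, atol=1e-3))     # True: u is d/dt of the interpolant
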
236
+ class PromptCondAudioDiffusion(nn.Module):
237
+ def __init__(
238
+ self,
239
+ num_channels,
240
+ unet_model_name=None,
241
+ unet_model_config_path=None,
242
+ snr_gamma=None,
243
+ uncondition=True,
244
+ out_paint=False,
245
+ ):
246
+ super().__init__()
247
+
248
+ assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
249
+
250
+ self.unet_model_name = unet_model_name
251
+ self.unet_model_config_path = unet_model_config_path
252
+ self.snr_gamma = snr_gamma
253
+ self.uncondition = uncondition
254
+ self.num_channels = num_channels
255
+
256
+ # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
257
+ self.normfeat = Feature1DProcessor(dim=64)
258
+
259
+ self.sample_rate = 48000
260
+ self.num_samples_perseg = self.sample_rate * 20 // 1000
261
+ self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
262
+ self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
263
+ # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
264
+ # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
265
+ self.bestrq = load_model(
266
+ model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq',
267
+ checkpoint_dir='ckpt/encode-s12k.pt',
268
+ )
269
+ self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
270
+ self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
271
+ for v in self.bestrq.parameters():v.requires_grad = False
272
+ self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
273
+ self.rvq_bestrq_bgm_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
274
+ self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
275
+ for v in self.hubert.parameters():v.requires_grad = False
276
+ self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
277
+ # self.xvecmodel = XVECModel()
278
+ config = GPT2Config(n_positions=1000,n_layer=16,n_head=20,n_embd=2200,n_inner=4400)
279
+ unet = GPT2Model(config)
280
+ mlp = nn.Sequential(
281
+ nn.Linear(2200, 1024),
282
+ nn.SiLU(),
283
+ nn.Linear(1024, 1024),
284
+ nn.SiLU(),
285
+ nn.Linear(1024, 768)
286
+ )
287
+ self.set_from = "random"
288
+ self.cfm_wrapper = BASECFM(unet, mlp)
289
+ self.mask_emb = torch.nn.Embedding(3, 24)
290
+ print("Transformer initialized from pretrain.")
291
+ torch.cuda.empty_cache()
292
+
293
+ def compute_snr(self, timesteps):
294
+ """
295
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
296
+ """
297
+ alphas_cumprod = self.noise_scheduler.alphas_cumprod
298
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
299
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
300
+
301
+ # Expand the tensors.
302
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
303
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
304
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
305
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
306
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
307
+
308
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
309
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
310
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
311
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
312
+
313
+ # Compute SNR.
314
+ snr = (alpha / sigma) ** 2
315
+ return snr
316
+
317
+ def preprocess_audio(self, input_audios, threshold=0.9):
318
+ assert len(input_audios.shape) == 2, input_audios.shape
319
+ norm_value = torch.ones_like(input_audios[:,0])
320
+ max_volume = input_audios.abs().max(dim=-1)[0]
321
+ norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
322
+ return input_audios/norm_value.unsqueeze(-1)
323
+
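preprocess_audio only rescales items whose peak exceeds the threshold, bringing that peak down to the threshold and leaving quieter items untouched. Its behaviour on a two-item toy batch:

    import torch

    threshold = 0.9
    audio = torch.stack([torch.linspace(-1.8, 1.8, 1000),    # peak 1.8 -> scaled down to 0.9
                         torch.linspace(-0.5, 0.5, 1000)])   # peak 0.5 -> left unchanged

    norm_value = torch.ones_like(audio[:, 0])
    max_volume = audio.abs().max(dim=-1)[0]
    norm_value[max_volume > threshold] = max_volume[max_volume > threshold] / threshold
    out = audio / norm_value.unsqueeze(-1)
    print(out.abs().max(dim=-1)[0])                          # tensor([0.9000, 0.5000])
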
324
+ def extract_wav2vec_embeds(self, input_audios,output_len):
325
+ wav2vec_stride = 2
326
+
327
+ wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
328
+ wav2vec_embeds_last=wav2vec_embeds[-1]
329
+ wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
330
+ return wav2vec_embeds_last
331
+
332
+ def extract_mert_embeds(self, input_audios):
333
+ prompt_stride = 3
334
+ inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
335
+ input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
336
+ prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
337
+ mert_emb= prompt_embeds[-1]
338
+ mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=375, mode='linear', align_corners=False).permute(0, 2, 1)
339
+
340
+ return mert_emb
341
+
342
+ def extract_bestrq_embeds(self, input_audio_vocal_0,input_audio_vocal_1,layer):
343
+ input_wav_mean = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0
344
+ input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
345
+ layer_results = input_wav_mean['layer_results']
346
+ bestrq_emb = layer_results[layer]
347
+ bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
348
+ return bestrq_emb
349
+
350
+
351
+ def extract_spk_embeds(self, input_audios):
352
+ spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
353
+ spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
354
+ return spk_embeds
355
+
356
+ def extract_lyric_feats(self, lyric):
357
+ with torch.no_grad():
358
+ try:
359
+ text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
360
+ except:
361
+ text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
362
+ text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
363
+ text_mask = text_mask.to(self.device)
364
+ text_encoder_hidden_states, text_mask, text_prompt_embeds = \
365
+ pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
366
+ text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
367
+ return text_encoder_hidden_states, text_mask
368
+
369
+ def extract_energy_bar(self, input_audios):
370
+ if(input_audios.shape[-1] % self.num_samples_perseg > 0):
371
+ energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
372
+ else:
373
+ energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
374
+ energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
375
+ energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
376
+ energy_embedding = self.energy_embedding(energy_bar)
377
+ energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
378
+ return energy_embedding
379
+
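extract_energy_bar converts each 20 ms frame to an RMS level in dB and buckets it into the integer range 0..16 before embedding. The bucketing step in isolation, on a few example RMS values:

    import torch

    rms = torch.tensor([1.0, 0.1, 0.01, 1e-4])       # per-frame RMS amplitudes
    db = (rms + 1e-6).log10() * 20                   # ~0, -20, -40, -80 dB
    bins = (db / 2.0 + 16).clamp(0, 16).int()
    print(bins.tolist())                             # [16, 6, 0, 0]
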
380
+ def forward(self, input_audios_vocal,input_audios_bgm, lyric, latents, latent_masks, validation_mode=False, \
381
+ additional_feats = ['spk', 'lyric'], \
382
+ train_rvq=True, train_ssl=False,layer_vocal=7,layer_bgm=7):
383
+ if not hasattr(self,"device"):
384
+ self.device = input_audios_vocal.device
385
+ if not hasattr(self,"dtype"):
386
+ self.dtype = input_audios_vocal.dtype
387
+ device = self.device
388
+ input_audio_vocal_0 = input_audios_vocal[:,0,:]
389
+ input_audio_vocal_1 = input_audios_vocal[:,1,:]
390
+ input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0)
391
+ input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1)
392
+ input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0
393
+
394
+ input_audio_bgm_0 = input_audios_bgm[:,0,:]
395
+ input_audio_bgm_1 = input_audios_bgm[:,1,:]
396
+ input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0)
397
+ input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1)
398
+ input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0
399
+
400
+ if(train_ssl):
401
+ self.wav2vec.train()
402
+ wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_vocal_wav2vec + input_audios_bgm_wav2vec)
403
+ self.clap_embd_extractor.train()
404
+ prompt_embeds = self.extract_mert_embeds(input_audios_vocal_wav2vec + input_audios_bgm_wav2vec)
405
+ if('spk' in additional_feats):
406
+ self.xvecmodel.train()
407
+ spk_embeds = self.extract_spk_embeds(input_audios_vocal_wav2vec).repeat(1,1,prompt_embeds.shape[-1]//2,1)  # speaker embedding from the vocal stem
408
+ else:
409
+ with torch.no_grad():
410
+ with autocast(enabled=False):
411
+ bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal)
412
+ bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm)
413
+ # mert_emb = self.extract_mert_embeds(input_audios_mert)
414
+ output_len = bestrq_emb.shape[2]
415
+ wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_vocal_wav2vec+input_audios_bgm_wav2vec,output_len)
416
+
417
+
418
+ bestrq_emb = bestrq_emb.detach()
419
+ bestrq_emb_bgm = bestrq_emb_bgm.detach()
420
+
421
+ if('lyric' in additional_feats):
422
+ text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
423
+ else:
424
+ text_encoder_hidden_states, text_mask = None, None
425
+
426
+
427
+ if(train_rvq):
428
+ quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
429
+ quantized_bestrq_emb_bgm, _, _, commitment_loss_bestrq_emb_bgm, codebook_loss_bestrq_emb_bgm,_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t
430
+ else:
431
+ bestrq_emb = bestrq_emb.float()
432
+ self.rvq_bestrq_emb.eval()
433
+ # with autocast(enabled=False):
434
+ quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
435
+ commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
436
+ codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
437
+ quantized_bestrq_emb = quantized_bestrq_emb.detach()
438
+
439
+ commitment_loss = commitment_loss_bestrq_emb+commitment_loss_bestrq_emb_bgm
440
+ codebook_loss = codebook_loss_bestrq_emb+codebook_loss_bestrq_emb_bgm
441
+
442
+
443
+ alpha=1
444
+ quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
445
+ quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm * alpha + bestrq_emb_bgm * (1-alpha)
446
+
447
+
448
+
449
+
450
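+ # Segment simulation: with 50% probability treat the batch as a continuation segment
+ # and mark a random 64-128 frame prefix as in-context conditioning (mask value 1).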
+ scenario = np.random.choice(['start_seg', 'other_seg'])
451
+ if(scenario == 'other_seg'):
452
+ for binx in range(input_audios_vocal.shape[0]):
453
+ # latent_masks[binx,0:64] = 1
454
+ latent_masks[binx,0:random.randint(64,128)] = 1
455
+ quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
456
+ quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm.permute(0,2,1).contiguous()
457
+ quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
458
+ + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
459
+ quantized_bestrq_emb_bgm = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb_bgm \
460
+ + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
461
+
462
+
463
+
464
+
465
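+ # Classifier-free guidance dropout: zero both condition streams for ~10% of the batch
+ # items so the model also learns an unconditional path.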
+ if self.uncondition:
466
+ mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
467
+ if len(mask_indices) > 0:
468
+ quantized_bestrq_emb[mask_indices] = 0
469
+ quantized_bestrq_emb_bgm[mask_indices] = 0
470
+ latents = latents.permute(0,2,1).contiguous()
471
+ latents = self.normfeat.project_sample(latents)
472
+ latents = latents.permute(0,2,1).contiguous()
473
+ incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
474
+ attention_mask=(latent_masks > 0.5)
475
+ B, L = attention_mask.size()
476
+ attention_mask = attention_mask.view(B, 1, L)
477
+ attention_mask = attention_mask * attention_mask.transpose(-1, -2)
478
+ attention_mask = attention_mask.unsqueeze(1)
479
+ latent_mask_input = self.mask_emb(latent_masks)
480
+ loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb,quantized_bestrq_emb_bgm], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
481
+ return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
482
+
483
+ def init_device_dtype(self, device, dtype):
484
+ self.device = device
485
+ self.dtype = dtype
486
+
487
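+ # Inference-time tokenization of a single stereo vocal/BGM pair: extract BEST-RQ
+ # features per stem and return their RVQ codes (fetch_codes_batch below does the
+ # same for a batch).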
+ @torch.no_grad()
488
+ def fetch_codes(self, input_audios_vocal,input_audios_bgm, additional_feats,layer_vocal=7,layer_bgm=7):
489
+ input_audio_vocal_0 = input_audios_vocal[[0],:]
490
+ input_audio_vocal_1 = input_audios_vocal[[1],:]
491
+ input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0)
492
+ input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1)
493
+ input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0
494
+
495
+ input_audio_bgm_0 = input_audios_bgm[[0],:]
496
+ input_audio_bgm_1 = input_audios_bgm[[1],:]
497
+ input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0)
498
+ input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1)
499
+ input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0
500
+
501
+ self.bestrq.eval()
502
+
503
+ # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
504
+ # bestrq_middle = bestrq_middle.detach()
505
+ # bestrq_last = bestrq_last.detach()
506
+ bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal)
507
+ bestrq_emb = bestrq_emb.detach()
508
+
509
+ bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm)
510
+ bestrq_emb_bgm = bestrq_emb_bgm.detach()
511
+
512
+
513
+
514
+ self.rvq_bestrq_emb.eval()
515
+ quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
516
+
517
+ self.rvq_bestrq_bgm_emb.eval()
518
+ quantized_bestrq_emb_bgm, codes_bestrq_emb_bgm, *_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t
519
+
520
+
521
+ if('spk' in additional_feats):
522
+ self.xvecmodel.eval()
523
+ spk_embeds = self.extract_spk_embeds(input_audios_vocal_wav2vec)  # speaker embedding from the vocal stem
524
+ else:
525
+ spk_embeds = None
526
+
527
+ # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
528
+ # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
529
+ # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
530
+ return [codes_bestrq_emb,codes_bestrq_emb_bgm], [bestrq_emb,bestrq_emb_bgm], spk_embeds
531
+ # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
532
+
533
+ @torch.no_grad()
534
+ def fetch_codes_batch(self, input_audios_vocal, input_audios_bgm, additional_feats,layer_vocal=7,layer_bgm=7):
535
+ input_audio_vocal_0 = input_audios_vocal[:,0,:]
536
+ input_audio_vocal_1 = input_audios_vocal[:,1,:]
537
+ input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0)
538
+ input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1)
539
+ input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0
540
+
541
+ input_audio_bgm_0 = input_audios_bgm[:,0,:]
542
+ input_audio_bgm_1 = input_audios_bgm[:,1,:]
543
+ input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0)
544
+ input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1)
545
+ input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0
546
+
547
+ self.bestrq.eval()
548
+
549
+ # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
550
+ # bestrq_middle = bestrq_middle.detach()
551
+ # bestrq_last = bestrq_last.detach()
552
+ bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal)
553
+ bestrq_emb = bestrq_emb.detach()
554
+
555
+ bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm)
556
+ bestrq_emb_bgm = bestrq_emb_bgm.detach()
557
+
558
+
559
+
560
+ self.rvq_bestrq_emb.eval()
561
+ quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
562
+
563
+ self.rvq_bestrq_bgm_emb.eval()
564
+ quantized_bestrq_emb_bgm, codes_bestrq_emb_bgm, *_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t
565
+
566
+
567
+ if('spk' in additional_feats):
568
+ self.xvecmodel.eval()
569
+ spk_embeds = self.extract_spk_embeds(input_audios_vocal_wav2vec)  # speaker embedding from the vocal stem
570
+ else:
571
+ spk_embeds = None
572
+
573
+ # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
574
+ # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
575
+ # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
576
+ return [codes_bestrq_emb,codes_bestrq_emb_bgm], [bestrq_emb,bestrq_emb_bgm], spk_embeds
577
+ # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
578
+
579
+
580
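+ # Decode RVQ codes back to quantized condition features, build the latent mask
+ # (2 = frames to generate, 1 = in-context prefix, 0 = unused) and integrate the
+ # flow-matching ODE with classifier-free guidance to produce latents.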
+ @torch.no_grad()
581
+ def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats,incontext_length=127,
582
+ guidance_scale=2, num_steps=20,
583
+ disable_progress=True, scenario='start_seg'):
584
+ classifier_free_guidance = guidance_scale > 1.0
585
+ device = self.device
586
+ dtype = self.dtype
587
+ # codes_bestrq_middle, codes_bestrq_last = codes
588
+ codes_bestrq_emb,codes_bestrq_emb_bgm = codes
589
+
590
+
591
+ batch_size = codes_bestrq_emb.shape[0]
592
+
593
+
594
+ quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
595
+ quantized_bestrq_emb_bgm,_,_=self.rvq_bestrq_bgm_emb.from_codes(codes_bestrq_emb_bgm)
596
+ quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
597
+ quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm.permute(0,2,1).contiguous()
598
+ if('spk' in additional_feats):
599
+ spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
600
+
601
+ num_frames = quantized_bestrq_emb.shape[1]
602
+
603
+ num_channels_latents = self.num_channels
604
+ shape = (batch_size, num_frames, 64)
605
+ latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
606
+
607
+
608
+
609
+ latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
610
+ latent_masks[:,0:latent_length] = 2
611
+ if(scenario=='other_seg'):
612
+ latent_masks[:,0:incontext_length] = 1
613
+
614
+
615
+
616
+ quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
617
+ + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
618
+ quantized_bestrq_emb_bgm = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb_bgm \
619
+ + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
620
+ true_latents = true_latents.permute(0,2,1).contiguous()
621
+ true_latents = self.normfeat.project_sample(true_latents)
622
+ true_latents = true_latents.permute(0,2,1).contiguous()
623
+ incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
624
+ incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
625
+
626
+
627
+ attention_mask=(latent_masks > 0.5)
628
+ B, L = attention_mask.size()
629
+ attention_mask = attention_mask.view(B, 1, L)
630
+ attention_mask = attention_mask * attention_mask.transpose(-1, -2)
631
+ attention_mask = attention_mask.unsqueeze(1)
632
+ latent_mask_input = self.mask_emb(latent_masks)
633
+
634
+ if('spk' in additional_feats):
635
+ # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
636
+ additional_model_input = torch.cat([quantized_bestrq_emb,quantized_bestrq_emb_bgm, spk_embeds],2)
637
+ else:
638
+ # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
639
+ additional_model_input = torch.cat([quantized_bestrq_emb,quantized_bestrq_emb_bgm],2)
640
+
641
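+ # Integrate the flow-matching ODE over num_steps uniform Euler steps starting from
+ # Gaussian noise, then copy the in-context prefix back over the generated latents.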
+ temperature = 1.0
642
+ t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
643
+ latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
644
+
645
+ latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
646
+ latents = latents.permute(0,2,1).contiguous()
647
+ latents = self.normfeat.return_sample(latents)
648
+ # latents = latents.permute(0,2,1).contiguous()
649
+ return latents
650
+
651
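+ # Convenience wrapper: tokenize the reference audio with fetch_codes, then decode
+ # the codes with inference_codes.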
+ @torch.no_grad()
652
+ def inference(self, input_audios_vocal,input_audios_bgm, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
653
+ disable_progress=True,layer_vocal=7,layer_bgm=3,scenario='start_seg'):
654
+ codes, embeds, spk_embeds = self.fetch_codes(input_audios_vocal,input_audios_bgm, additional_feats,layer_vocal,layer_bgm)
655
+
656
+ latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
657
+ guidance_scale=guidance_scale, num_steps=num_steps, \
658
+ disable_progress=disable_progress,scenario=scenario)
659
+ return latents
660
+
661
+ def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
662
+ divisor = 4
663
+ shape = (batch_size, num_channels_latents, num_frames, 32)
664
+ if(num_frames%divisor>0):
665
+ num_frames = round(num_frames/float(divisor))*divisor
666
+ shape = (batch_size, num_channels_latents, num_frames, 32)
667
+ latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
668
+ return latents
669
+
670
+
codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_additionalemb.py CHANGED
The diff for this file is too large to render. See raw diff
 
codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_flow.py CHANGED
The diff for this file is too large to render. See raw diff
 
codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/pinyin/symbols.py CHANGED
@@ -1,71 +1,71 @@
1
- _pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"]
2
-
3
- _initials = [
4
- "^",
5
- "b",
6
- "c",
7
- "ch",
8
- "d",
9
- "f",
10
- "g",
11
- "h",
12
- "j",
13
- "k",
14
- "l",
15
- "m",
16
- "n",
17
- "p",
18
- "q",
19
- "r",
20
- "s",
21
- "sh",
22
- "t",
23
- "x",
24
- "z",
25
- "zh",
26
- ]
27
-
28
- _tones = ["1", "2", "3", "4", "5"]
29
-
30
- _finals = [
31
- "a",
32
- "ai",
33
- "an",
34
- "ang",
35
- "ao",
36
- "e",
37
- "ei",
38
- "en",
39
- "eng",
40
- "er",
41
- "i",
42
- "ia",
43
- "ian",
44
- "iang",
45
- "iao",
46
- "ie",
47
- "ii",
48
- "iii",
49
- "in",
50
- "ing",
51
- "iong",
52
- "iou",
53
- "o",
54
- "ong",
55
- "ou",
56
- "u",
57
- "ua",
58
- "uai",
59
- "uan",
60
- "uang",
61
- "uei",
62
- "uen",
63
- "ueng",
64
- "uo",
65
- "v",
66
- "van",
67
- "ve",
68
- "vn",
69
- ]
70
-
71
- symbols = _pause + _initials + [i + j for i in _finals for j in _tones]
 
1
+ _pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"]
2
+
3
+ _initials = [
4
+ "^",
5
+ "b",
6
+ "c",
7
+ "ch",
8
+ "d",
9
+ "f",
10
+ "g",
11
+ "h",
12
+ "j",
13
+ "k",
14
+ "l",
15
+ "m",
16
+ "n",
17
+ "p",
18
+ "q",
19
+ "r",
20
+ "s",
21
+ "sh",
22
+ "t",
23
+ "x",
24
+ "z",
25
+ "zh",
26
+ ]
27
+
28
+ _tones = ["1", "2", "3", "4", "5"]
29
+
30
+ _finals = [
31
+ "a",
32
+ "ai",
33
+ "an",
34
+ "ang",
35
+ "ao",
36
+ "e",
37
+ "ei",
38
+ "en",
39
+ "eng",
40
+ "er",
41
+ "i",
42
+ "ia",
43
+ "ian",
44
+ "iang",
45
+ "iao",
46
+ "ie",
47
+ "ii",
48
+ "iii",
49
+ "in",
50
+ "ing",
51
+ "iong",
52
+ "iou",
53
+ "o",
54
+ "ong",
55
+ "ou",
56
+ "u",
57
+ "ua",
58
+ "uai",
59
+ "uan",
60
+ "uang",
61
+ "uei",
62
+ "uen",
63
+ "ueng",
64
+ "uo",
65
+ "v",
66
+ "van",
67
+ "ve",
68
+ "vn",
69
+ ]
70
+
71
+ symbols = _pause + _initials + [i + j for i in _finals for j in _tones]
codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k.py CHANGED
@@ -1,47 +1,47 @@
1
- import json
2
- import torch
3
- from tqdm import tqdm
4
- import torchaudio
5
- import librosa
6
- import os
7
- import math
8
- import numpy as np
9
- from tools.get_bsrnnvae import get_bsrnnvae
10
- import tools.torch_tools as torch_tools
11
-
12
- class Tango:
13
- def __init__(self, \
14
- device="cuda:0"):
15
-
16
- self.sample_rate = 44100
17
- self.device = device
18
-
19
- self.vae = get_bsrnnvae()
20
- self.vae = self.vae.eval().to(device)
21
-
22
- def sound2sound_generate_longterm(self, fname, batch_size=1, duration=15.36, steps=200, disable_progress=False):
23
- """ Genrate audio without condition. """
24
- num_frames = math.ceil(duration * 100. / 8)
25
- with torch.no_grad():
26
- orig_samples, fs = torchaudio.load(fname)
27
- if(fs!=44100):
28
- orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100)
29
- fs = 44100
30
- if(orig_samples.shape[-1]<int(duration*44100*2)):
31
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*44100*2+480)-orig_samples.shape[-1], \
32
- dtype=orig_samples.dtype, device=orig_samples.device)], -1)
33
- # orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
34
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
35
- if(fs!=44100):orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100)
36
- # resampled_audios = orig_samples[[0],int(4.64*44100):int(35.36*48000)+480].clamp(-1,1)
37
- resampled_audios = orig_samples[[0],0:int(duration*2*44100)+480].clamp(-1,1)
38
- orig_samples = orig_samples[[0],0:int(duration*2*44100)]
39
-
40
- audio = self.vae(orig_samples[:,None,:])[:,0,:]
41
-
42
- if(orig_samples.shape[-1]<audio.shape[-1]):
43
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
44
- else:
45
- orig_samples = orig_samples[:,0:audio.shape[-1]]
46
- output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
47
- return output
 
1
+ import json
2
+ import torch
3
+ from tqdm import tqdm
4
+ import torchaudio
5
+ import librosa
6
+ import os
7
+ import math
8
+ import numpy as np
9
+ from tools.get_bsrnnvae import get_bsrnnvae
10
+ import tools.torch_tools as torch_tools
11
+
12
+ class Tango:
13
+ def __init__(self, \
14
+ device="cuda:0"):
15
+
16
+ self.sample_rate = 44100
17
+ self.device = device
18
+
19
+ self.vae = get_bsrnnvae()
20
+ self.vae = self.vae.eval().to(device)
21
+
22
+ def sound2sound_generate_longterm(self, fname, batch_size=1, duration=15.36, steps=200, disable_progress=False):
23
+ """ Genrate audio without condition. """
24
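+ # Reconstructs the waveform through the BSRNN VAE and returns a (2, T) tensor:
+ # row 0 is the padded original, row 1 the reconstruction.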
+ num_frames = math.ceil(duration * 100. / 8)
25
+ with torch.no_grad():
26
+ orig_samples, fs = torchaudio.load(fname)
27
+ if(fs!=44100):
28
+ orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100)
29
+ fs = 44100
30
+ if(orig_samples.shape[-1]<int(duration*44100*2)):
31
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*44100*2+480)-orig_samples.shape[-1], \
32
+ dtype=orig_samples.dtype, device=orig_samples.device)], -1)
33
+ # orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
34
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
35
+ if(fs!=44100):orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100)
36
+ # resampled_audios = orig_samples[[0],int(4.64*44100):int(35.36*48000)+480].clamp(-1,1)
37
+ resampled_audios = orig_samples[[0],0:int(duration*2*44100)+480].clamp(-1,1)
38
+ orig_samples = orig_samples[[0],0:int(duration*2*44100)]
39
+
40
+ audio = self.vae(orig_samples[:,None,:])[:,0,:]
41
+
42
+ if(orig_samples.shape[-1]<audio.shape[-1]):
43
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
44
+ else:
45
+ orig_samples = orig_samples[:,0:audio.shape[-1]]
46
+ output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
47
+ return output
codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k_vocal.py CHANGED
@@ -1,47 +1,47 @@
1
- import json
2
- import torch
3
- from tqdm import tqdm
4
- import torchaudio
5
- import librosa
6
- import os
7
- import math
8
- import numpy as np
9
- from tools.get_bsrnnvae import get_bsrnnvae
10
- import tools.torch_tools as torch_tools
11
-
12
- class Tango:
13
- def __init__(self, \
14
- device="cuda:0"):
15
-
16
- self.sample_rate = 44100
17
- self.device = device
18
-
19
- self.vae = get_bsrnnvae()
20
- self.vae = self.vae.eval().to(device)
21
-
22
- def sound2sound_generate_longterm(self, fname, batch_size=1, duration=20.48, steps=200, disable_progress=False):
23
- """ Genrate audio without condition. """
24
- num_frames = math.ceil(duration * 100. / 8)
25
- with torch.no_grad():
26
- orig_samples, fs = torchaudio.load(fname)
27
- if(fs!=44100):
28
- orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100)
29
- fs = 44100
30
- if(orig_samples.shape[-1]<int(duration*44100*2)):
31
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*44100*2+480)-orig_samples.shape[-1], \
32
- dtype=orig_samples.dtype, device=orig_samples.device)], -1)
33
- # orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
34
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
35
- if(fs!=44100):orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100)
36
- # resampled_audios = orig_samples[[0],int(4.64*44100):int(35.36*48000)+480].clamp(-1,1)
37
- resampled_audios = orig_samples[[0],0:int(duration*2*44100)+480].clamp(-1,1)
38
- orig_samples = orig_samples[[0],0:int(duration*2*44100)]
39
-
40
- audio = self.vae(orig_samples[:,None,:])[:,0,:]
41
-
42
- if(orig_samples.shape[-1]<audio.shape[-1]):
43
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
44
- else:
45
- orig_samples = orig_samples[:,0:audio.shape[-1]]
46
- output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
47
- return output
 
1
+ import json
2
+ import torch
3
+ from tqdm import tqdm
4
+ import torchaudio
5
+ import librosa
6
+ import os
7
+ import math
8
+ import numpy as np
9
+ from tools.get_bsrnnvae import get_bsrnnvae
10
+ import tools.torch_tools as torch_tools
11
+
12
+ class Tango:
13
+ def __init__(self, \
14
+ device="cuda:0"):
15
+
16
+ self.sample_rate = 44100
17
+ self.device = device
18
+
19
+ self.vae = get_bsrnnvae()
20
+ self.vae = self.vae.eval().to(device)
21
+
22
+ def sound2sound_generate_longterm(self, fname, batch_size=1, duration=20.48, steps=200, disable_progress=False):
23
+ """ Genrate audio without condition. """
24
+ num_frames = math.ceil(duration * 100. / 8)
25
+ with torch.no_grad():
26
+ orig_samples, fs = torchaudio.load(fname)
27
+ if(fs!=44100):
28
+ orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100)
29
+ fs = 44100
30
+ if(orig_samples.shape[-1]<int(duration*44100*2)):
31
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*44100*2+480)-orig_samples.shape[-1], \
32
+ dtype=orig_samples.dtype, device=orig_samples.device)], -1)
33
+ # orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
34
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
35
+ if(fs!=44100):orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100)
36
+ # resampled_audios = orig_samples[[0],int(4.64*44100):int(35.36*48000)+480].clamp(-1,1)
37
+ resampled_audios = orig_samples[[0],0:int(duration*2*44100)+480].clamp(-1,1)
38
+ orig_samples = orig_samples[[0],0:int(duration*2*44100)]
39
+
40
+ audio = self.vae(orig_samples[:,None,:])[:,0,:]
41
+
42
+ if(orig_samples.shape[-1]<audio.shape[-1]):
43
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
44
+ else:
45
+ orig_samples = orig_samples[:,0:audio.shape[-1]]
46
+ output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
47
+ return output
codeclm/tokenizer/Flow1dVAE/tools/infer_hifigan48k_speech.py CHANGED
@@ -1,56 +1,56 @@
1
- import json
2
- import torch
3
- from tqdm import tqdm
4
- import torchaudio
5
- import librosa
6
- import os
7
- import math
8
- import numpy as np
9
- from get_melvaehifigan48k import build_pretrained_models
10
- import tools.torch_tools as torch_tools
11
-
12
- class Tango:
13
- def __init__(self, \
14
- device="cuda:0"):
15
-
16
- self.sample_rate = 48000
17
- self.device = device
18
-
19
- self.vae, self.stft = build_pretrained_models()
20
- self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
21
-
22
- def mel_spectrogram_to_waveform(self, mel_spectrogram):
23
- if mel_spectrogram.dim() == 4:
24
- mel_spectrogram = mel_spectrogram.squeeze(1)
25
-
26
- waveform = self.vocoder(mel_spectrogram)
27
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
28
- waveform = waveform.cpu().float()
29
- return waveform
30
-
31
- def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
32
- """ Genrate audio without condition. """
33
- num_frames = math.ceil(duration * 100. / 8)
34
- with torch.no_grad():
35
- orig_samples, fs = torchaudio.load(fname)
36
- if(orig_samples.shape[-1]<int(duration*48000)):
37
- orig_samples = orig_samples.repeat(1,math.ceil(int(duration*48000)/float(orig_samples.shape[-1])))
38
- # orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
39
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
40
- if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
41
- # resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
42
- resampled_audios = orig_samples[[0],0:int(duration*48000)+480].clamp(-1,1)
43
- orig_samples = orig_samples[[0],0:int(duration*48000)]
44
-
45
- mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
46
- mel = mel.unsqueeze(1).to(self.device)
47
-
48
- audio = self.vae.decode_to_waveform(mel)
49
- audio = torch.from_numpy(audio)
50
-
51
- if(orig_samples.shape[-1]<audio.shape[-1]):
52
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
53
- else:
54
- orig_samples = orig_samples[:,0:audio.shape[-1]]
55
- output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
56
- return output
 
1
+ import json
2
+ import torch
3
+ from tqdm import tqdm
4
+ import torchaudio
5
+ import librosa
6
+ import os
7
+ import math
8
+ import numpy as np
9
+ from get_melvaehifigan48k import build_pretrained_models
10
+ import tools.torch_tools as torch_tools
11
+
12
+ class Tango:
13
+ def __init__(self, \
14
+ device="cuda:0"):
15
+
16
+ self.sample_rate = 48000
17
+ self.device = device
18
+
19
+ self.vae, self.stft = build_pretrained_models()
20
+ self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
21
+
22
+ def mel_spectrogram_to_waveform(self, mel_spectrogram):
23
+ if mel_spectrogram.dim() == 4:
24
+ mel_spectrogram = mel_spectrogram.squeeze(1)
25
+
26
+ waveform = self.vocoder(mel_spectrogram)
27
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
28
+ waveform = waveform.cpu().float()
29
+ return waveform
30
+
31
+ def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
32
+ """ Genrate audio without condition. """
33
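+ # Mel round trip: compute the mel spectrogram with the pretrained STFT and decode it
+ # straight back to a waveform (no latent encoding), returning the original and the
+ # reconstruction stacked as (2, T).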
+ num_frames = math.ceil(duration * 100. / 8)
34
+ with torch.no_grad():
35
+ orig_samples, fs = torchaudio.load(fname)
36
+ if(orig_samples.shape[-1]<int(duration*48000)):
37
+ orig_samples = orig_samples.repeat(1,math.ceil(int(duration*48000)/float(orig_samples.shape[-1])))
38
+ # orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
39
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
40
+ if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
41
+ # resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
42
+ resampled_audios = orig_samples[[0],0:int(duration*48000)+480].clamp(-1,1)
43
+ orig_samples = orig_samples[[0],0:int(duration*48000)]
44
+
45
+ mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
46
+ mel = mel.unsqueeze(1).to(self.device)
47
+
48
+ audio = self.vae.decode_to_waveform(mel)
49
+ audio = torch.from_numpy(audio)
50
+
51
+ if(orig_samples.shape[-1]<audio.shape[-1]):
52
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
53
+ else:
54
+ orig_samples = orig_samples[:,0:audio.shape[-1]]
55
+ output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
56
+ return output
codeclm/tokenizer/Flow1dVAE/tools/infer_hifigan48k_vocal.py CHANGED
@@ -1,57 +1,57 @@
1
- import json
2
- import torch
3
- from tqdm import tqdm
4
- import torchaudio
5
- import librosa
6
- import os
7
- import math
8
- import numpy as np
9
- from get_melvaehifigan48k import build_pretrained_models
10
- import tools.torch_tools as torch_tools
11
-
12
- class Tango:
13
- def __init__(self, \
14
- device="cuda:0"):
15
-
16
- self.sample_rate = 48000
17
- self.device = device
18
-
19
- self.vae, self.stft = build_pretrained_models()
20
- self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
21
-
22
- def mel_spectrogram_to_waveform(self, mel_spectrogram):
23
- if mel_spectrogram.dim() == 4:
24
- mel_spectrogram = mel_spectrogram.squeeze(1)
25
-
26
- waveform = self.vocoder(mel_spectrogram)
27
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
28
- waveform = waveform.cpu().float()
29
- return waveform
30
-
31
- def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
32
- """ Genrate audio without condition. """
33
- num_frames = math.ceil(duration * 100. / 8)
34
- with torch.no_grad():
35
- orig_samples, fs = torchaudio.load(fname)
36
- if(orig_samples.shape[-1]<int(duration*48000*2)):
37
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*48000*2+480)-orig_samples.shape[-1], \
38
- dtype=orig_samples.dtype, device=orig_samples.device)], -1)
39
- # orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
40
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
41
- if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
42
- # resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
43
- resampled_audios = orig_samples[[0],0:int(duration*2*48000)+480].clamp(-1,1)
44
- orig_samples = orig_samples[[0],0:int(duration*2*48000)]
45
-
46
- mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
47
- mel = mel.unsqueeze(1).to(self.device)
48
-
49
- audio = self.vae.decode_to_waveform(mel)
50
- audio = torch.from_numpy(audio)
51
-
52
- if(orig_samples.shape[-1]<audio.shape[-1]):
53
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
54
- else:
55
- orig_samples = orig_samples[:,0:audio.shape[-1]]
56
- output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
57
- return output
 
1
+ import json
2
+ import torch
3
+ from tqdm import tqdm
4
+ import torchaudio
5
+ import librosa
6
+ import os
7
+ import math
8
+ import numpy as np
9
+ from get_melvaehifigan48k import build_pretrained_models
10
+ import tools.torch_tools as torch_tools
11
+
12
+ class Tango:
13
+ def __init__(self, \
14
+ device="cuda:0"):
15
+
16
+ self.sample_rate = 48000
17
+ self.device = device
18
+
19
+ self.vae, self.stft = build_pretrained_models()
20
+ self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
21
+
22
+ def mel_spectrogram_to_waveform(self, mel_spectrogram):
23
+ if mel_spectrogram.dim() == 4:
24
+ mel_spectrogram = mel_spectrogram.squeeze(1)
25
+
26
+ waveform = self.vocoder(mel_spectrogram)
27
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
28
+ waveform = waveform.cpu().float()
29
+ return waveform
30
+
31
+ def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
32
+ """ Genrate audio without condition. """
33
+ num_frames = math.ceil(duration * 100. / 8)
34
+ with torch.no_grad():
35
+ orig_samples, fs = torchaudio.load(fname)
36
+ if(orig_samples.shape[-1]<int(duration*48000*2)):
37
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*48000*2+480)-orig_samples.shape[-1], \
38
+ dtype=orig_samples.dtype, device=orig_samples.device)], -1)
39
+ # orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
40
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
41
+ if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
42
+ # resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
43
+ resampled_audios = orig_samples[[0],0:int(duration*2*48000)+480].clamp(-1,1)
44
+ orig_samples = orig_samples[[0],0:int(duration*2*48000)]
45
+
46
+ mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
47
+ mel = mel.unsqueeze(1).to(self.device)
48
+
49
+ audio = self.vae.decode_to_waveform(mel)
50
+ audio = torch.from_numpy(audio)
51
+
52
+ if(orig_samples.shape[-1]<audio.shape[-1]):
53
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
54
+ else:
55
+ orig_samples = orig_samples[:,0:audio.shape[-1]]
56
+ output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
57
+ return output
codeclm/tokenizer/Flow1dVAE/tools/infer_vaehifigan48k.py CHANGED
@@ -1,59 +1,59 @@
1
- import json
2
- import torch
3
- from tqdm import tqdm
4
- import torchaudio
5
- import librosa
6
- import os
7
- import math
8
- import numpy as np
9
- from get_melvaehifigan48k import build_pretrained_models
10
- import tools.torch_tools as torch_tools
11
-
12
- class Tango:
13
- def __init__(self, \
14
- device="cuda:0"):
15
-
16
- self.sample_rate = 48000
17
- self.device = device
18
-
19
- self.vae, self.stft = build_pretrained_models()
20
- self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
21
-
22
- def mel_spectrogram_to_waveform(self, mel_spectrogram):
23
- if mel_spectrogram.dim() == 4:
24
- mel_spectrogram = mel_spectrogram.squeeze(1)
25
-
26
- waveform = self.vocoder(mel_spectrogram)
27
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
28
- waveform = waveform.cpu().float()
29
- return waveform
30
-
31
- def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
32
- """ Genrate audio without condition. """
33
- num_frames = math.ceil(duration * 100. / 8)
34
- with torch.no_grad():
35
- orig_samples, fs = torchaudio.load(fname)
36
- if(orig_samples.shape[-1]<int(duration*48000*3)):
37
- orig_samples = orig_samples.repeat(1,math.ceil(int(duration*48000*3)/float(orig_samples.shape[-1])))
38
- # orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
39
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
40
- if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
41
- # resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
42
- resampled_audios = orig_samples[[0],int(0*48000):int(duration*3*48000)+480].clamp(-1,1)
43
- orig_samples = orig_samples[[0],:]
44
-
45
- mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
46
- mel = mel.unsqueeze(1).to(self.device)
47
- latents = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0)
48
-
49
- mel = self.vae.decode_first_stage(latents)
50
- audio = self.vae.decode_to_waveform(mel)
51
- audio = torch.from_numpy(audio)
52
-
53
- orig_samples = orig_samples[...,0:int(duration * 3 * 48000)]
54
- if(orig_samples.shape[-1]<audio.shape[-1]):
55
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
56
- else:
57
- orig_samples = orig_samples[:,0:audio.shape[-1]]
58
- output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
59
- return output
 
1
+ import json
2
+ import torch
3
+ from tqdm import tqdm
4
+ import torchaudio
5
+ import librosa
6
+ import os
7
+ import math
8
+ import numpy as np
9
+ from get_melvaehifigan48k import build_pretrained_models
10
+ import tools.torch_tools as torch_tools
11
+
12
+ class Tango:
13
+ def __init__(self, \
14
+ device="cuda:0"):
15
+
16
+ self.sample_rate = 48000
17
+ self.device = device
18
+
19
+ self.vae, self.stft = build_pretrained_models()
20
+ self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
21
+
22
+ def mel_spectrogram_to_waveform(self, mel_spectrogram):
23
+ if mel_spectrogram.dim() == 4:
24
+ mel_spectrogram = mel_spectrogram.squeeze(1)
25
+
26
+ waveform = self.vocoder(mel_spectrogram)
27
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
28
+ waveform = waveform.cpu().float()
29
+ return waveform
30
+
31
+ def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
32
+ """ Genrate audio without condition. """
33
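+ # Full round trip through the mel VAE: mel -> latents -> decoded mel -> waveform;
+ # returns the padded original and the reconstruction stacked as (2, T).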
+ num_frames = math.ceil(duration * 100. / 8)
34
+ with torch.no_grad():
35
+ orig_samples, fs = torchaudio.load(fname)
36
+ if(orig_samples.shape[-1]<int(duration*48000*3)):
37
+ orig_samples = orig_samples.repeat(1,math.ceil(int(duration*48000*3)/float(orig_samples.shape[-1])))
38
+ # orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
39
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
40
+ if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
41
+ # resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
42
+ resampled_audios = orig_samples[[0],int(0*48000):int(duration*3*48000)+480].clamp(-1,1)
43
+ orig_samples = orig_samples[[0],:]
44
+
45
+ mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
46
+ mel = mel.unsqueeze(1).to(self.device)
47
+ latents = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0)
48
+
49
+ mel = self.vae.decode_first_stage(latents)
50
+ audio = self.vae.decode_to_waveform(mel)
51
+ audio = torch.from_numpy(audio)
52
+
53
+ orig_samples = orig_samples[...,0:int(duration * 3 * 48000)]
54
+ if(orig_samples.shape[-1]<audio.shape[-1]):
55
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
56
+ else:
57
+ orig_samples = orig_samples[:,0:audio.shape[-1]]
58
+ output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
59
+ return output
codeclm/tokenizer/Flow1dVAE/tools/infer_vaehifigan48k_soundmusic.py CHANGED
@@ -1,61 +1,61 @@
1
- import json
2
- import torch
3
- from tqdm import tqdm
4
- import torchaudio
5
- import librosa
6
- import os
7
- import math
8
- import numpy as np
9
- from get_melvaehifigan48k import build_pretrained_models
10
- import tools.torch_tools as torch_tools
11
-
12
- class Tango:
13
- def __init__(self, \
14
- device="cuda:0"):
15
-
16
- self.sample_rate = 48000
17
- self.device = device
18
-
19
- self.vae, self.stft = build_pretrained_models()
20
- self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
21
-
22
- # print(sum(p.numel() for p in self.vae.parameters()));exit()
23
-
24
- def mel_spectrogram_to_waveform(self, mel_spectrogram):
25
- if mel_spectrogram.dim() == 4:
26
- mel_spectrogram = mel_spectrogram.squeeze(1)
27
-
28
- waveform = self.vocoder(mel_spectrogram)
29
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
30
- waveform = waveform.cpu().float()
31
- return waveform
32
-
33
- def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
34
- """ Genrate audio without condition. """
35
- num_frames = math.ceil(duration * 100. / 8)
36
- with torch.no_grad():
37
- orig_samples, fs = torchaudio.load(fname)
38
- if(orig_samples.shape[-1]<int(duration*48000)):
39
- orig_samples = orig_samples.repeat(1,math.ceil(int(duration*48000)/float(orig_samples.shape[-1])))
40
- # orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
41
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
42
- if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
43
- # resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
44
- resampled_audios = orig_samples[[0],int(0*48000):int(duration*48000)+480].clamp(-1,1)
45
- orig_samples = orig_samples[[0],:]
46
-
47
- mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
48
- mel = mel.unsqueeze(1).to(self.device)
49
- latents = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0)
50
-
51
- mel = self.vae.decode_first_stage(latents)
52
- audio = self.vae.decode_to_waveform(mel)
53
- audio = torch.from_numpy(audio)
54
-
55
- orig_samples = orig_samples[...,0:int(duration * 48000)]
56
- if(orig_samples.shape[-1]<audio.shape[-1]):
57
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
58
- else:
59
- orig_samples = orig_samples[:,0:audio.shape[-1]]
60
- output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
61
- return output
 
1
+ import json
2
+ import torch
3
+ from tqdm import tqdm
4
+ import torchaudio
5
+ import librosa
6
+ import os
7
+ import math
8
+ import numpy as np
9
+ from get_melvaehifigan48k import build_pretrained_models
10
+ import tools.torch_tools as torch_tools
11
+
12
+ class Tango:
13
+ def __init__(self, \
14
+ device="cuda:0"):
15
+
16
+ self.sample_rate = 48000
17
+ self.device = device
18
+
19
+ self.vae, self.stft = build_pretrained_models()
20
+ self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
21
+
22
+ # print(sum(p.numel() for p in self.vae.parameters()));exit()
23
+
24
+ def mel_spectrogram_to_waveform(self, mel_spectrogram):
25
+ if mel_spectrogram.dim() == 4:
26
+ mel_spectrogram = mel_spectrogram.squeeze(1)
27
+
28
+ waveform = self.vocoder(mel_spectrogram)
29
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
30
+ waveform = waveform.cpu().float()
31
+ return waveform
32
+
33
+ def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
34
+ """ Genrate audio without condition. """
35
+ num_frames = math.ceil(duration * 100. / 8)
36
+ with torch.no_grad():
37
+ orig_samples, fs = torchaudio.load(fname)
38
+ if(orig_samples.shape[-1]<int(duration*48000)):
39
+ orig_samples = orig_samples.repeat(1,math.ceil(int(duration*48000)/float(orig_samples.shape[-1])))
40
+ # orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
41
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
42
+ if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
43
+ # resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
44
+ resampled_audios = orig_samples[[0],int(0*48000):int(duration*48000)+480].clamp(-1,1)
45
+ orig_samples = orig_samples[[0],:]
46
+
47
+ mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
48
+ mel = mel.unsqueeze(1).to(self.device)
49
+ latents = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0)
50
+
51
+ mel = self.vae.decode_first_stage(latents)
52
+ audio = self.vae.decode_to_waveform(mel)
53
+ audio = torch.from_numpy(audio)
54
+
55
+ orig_samples = orig_samples[...,0:int(duration * 48000)]
56
+ if(orig_samples.shape[-1]<audio.shape[-1]):
57
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
58
+ else:
59
+ orig_samples = orig_samples[:,0:audio.shape[-1]]
60
+ output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
61
+ return output
codeclm/tokenizer/Flow1dVAE/tools/infer_vaehifigan48k_speech.py CHANGED
@@ -1,58 +1,58 @@
1
- import json
2
- import torch
3
- from tqdm import tqdm
4
- import torchaudio
5
- import librosa
6
- import os
7
- import math
8
- import numpy as np
9
- from get_melvaehifigan48k import build_pretrained_models
10
- import tools.torch_tools as torch_tools
11
-
12
- class Tango:
13
- def __init__(self, \
14
- device="cuda:0"):
15
-
16
- self.sample_rate = 48000
17
- self.device = device
18
-
19
- self.vae, self.stft = build_pretrained_models()
20
- self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
21
-
22
- def mel_spectrogram_to_waveform(self, mel_spectrogram):
23
- if mel_spectrogram.dim() == 4:
24
- mel_spectrogram = mel_spectrogram.squeeze(1)
25
-
26
- waveform = self.vocoder(mel_spectrogram)
27
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
28
- waveform = waveform.cpu().float()
29
- return waveform
30
-
31
- def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
32
- """ Genrate audio without condition. """
33
- num_frames = math.ceil(duration * 100. / 8)
34
- with torch.no_grad():
35
- orig_samples, fs = torchaudio.load(fname)
36
- if(orig_samples.shape[-1]<int(duration*48000)):
37
- orig_samples = orig_samples.repeat(1,math.ceil(int(duration*48000)/float(orig_samples.shape[-1])))
38
- # orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
39
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
40
- if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
41
- # resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
42
- resampled_audios = orig_samples[[0],0:int(duration*48000)+480].clamp(-1,1)
43
- orig_samples = orig_samples[[0],0:int(duration*48000)]
44
-
45
- mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
46
- mel = mel.unsqueeze(1).to(self.device)
47
- latents = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0)
48
-
49
- mel = self.vae.decode_first_stage(latents)
50
- audio = self.vae.decode_to_waveform(mel)
51
- audio = torch.from_numpy(audio)
52
-
53
- if(orig_samples.shape[-1]<audio.shape[-1]):
54
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
55
- else:
56
- orig_samples = orig_samples[:,0:audio.shape[-1]]
57
- output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
58
- return output
 
1
+ import json
2
+ import torch
3
+ from tqdm import tqdm
4
+ import torchaudio
5
+ import librosa
6
+ import os
7
+ import math
8
+ import numpy as np
9
+ from get_melvaehifigan48k import build_pretrained_models
10
+ import tools.torch_tools as torch_tools
11
+
12
+ class Tango:
13
+ def __init__(self, \
14
+ device="cuda:0"):
15
+
16
+ self.sample_rate = 48000
17
+ self.device = device
18
+
19
+ self.vae, self.stft = build_pretrained_models()
20
+ self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
21
+
22
+ def mel_spectrogram_to_waveform(self, mel_spectrogram):
23
+ if mel_spectrogram.dim() == 4:
24
+ mel_spectrogram = mel_spectrogram.squeeze(1)
25
+
26
+ waveform = self.vocoder(mel_spectrogram)
27
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
28
+ waveform = waveform.cpu().float()
29
+ return waveform
30
+
31
+ def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
32
+ """ Genrate audio without condition. """
33
+ num_frames = math.ceil(duration * 100. / 8)
34
+ with torch.no_grad():
35
+ orig_samples, fs = torchaudio.load(fname)
36
+ if(orig_samples.shape[-1]<int(duration*48000)):
37
+ orig_samples = orig_samples.repeat(1,math.ceil(int(duration*48000)/float(orig_samples.shape[-1])))
38
+ # orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
39
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
40
+ if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
41
+ # resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
42
+ resampled_audios = orig_samples[[0],0:int(duration*48000)+480].clamp(-1,1)
43
+ orig_samples = orig_samples[[0],0:int(duration*48000)]
44
+
45
+ mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
46
+ mel = mel.unsqueeze(1).to(self.device)
47
+ latents = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0)
48
+
49
+ mel = self.vae.decode_first_stage(latents)
50
+ audio = self.vae.decode_to_waveform(mel)
51
+ audio = torch.from_numpy(audio)
52
+
53
+ if(orig_samples.shape[-1]<audio.shape[-1]):
54
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
55
+ else:
56
+ orig_samples = orig_samples[:,0:audio.shape[-1]]
57
+ output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
58
+ return output
codeclm/tokenizer/Flow1dVAE/tools/infer_vaehifigan48k_vocal.py CHANGED
@@ -1,59 +1,59 @@
1
- import json
2
- import torch
3
- from tqdm import tqdm
4
- import torchaudio
5
- import librosa
6
- import os
7
- import math
8
- import numpy as np
9
- from get_melvaehifigan48k import build_pretrained_models
10
- import tools.torch_tools as torch_tools
11
-
12
- class Tango:
13
- def __init__(self, \
14
- device="cuda:0"):
15
-
16
- self.sample_rate = 48000
17
- self.device = device
18
-
19
- self.vae, self.stft = build_pretrained_models()
20
- self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
21
-
22
- def mel_spectrogram_to_waveform(self, mel_spectrogram):
23
- if mel_spectrogram.dim() == 4:
24
- mel_spectrogram = mel_spectrogram.squeeze(1)
25
-
26
- waveform = self.vocoder(mel_spectrogram)
27
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
28
- waveform = waveform.cpu().float()
29
- return waveform
30
-
31
- def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
32
- """ Genrate audio without condition. """
33
- num_frames = math.ceil(duration * 100. / 8)
34
- with torch.no_grad():
35
- orig_samples, fs = torchaudio.load(fname)
36
- if(orig_samples.shape[-1]<int(duration*48000*2)):
37
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*48000*2+480)-orig_samples.shape[-1], \
38
- dtype=orig_samples.dtype, device=orig_samples.device)], -1)
39
- # orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
40
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
41
- if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
42
- # resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
43
- resampled_audios = orig_samples[[0],0:int(duration*2*48000)+480].clamp(-1,1)
44
- orig_samples = orig_samples[[0],0:int(duration*2*48000)]
45
-
46
- mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
47
- mel = mel.unsqueeze(1).to(self.device)
48
- latents = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0)
49
-
50
- mel = self.vae.decode_first_stage(latents)
51
- audio = self.vae.decode_to_waveform(mel)
52
- audio = torch.from_numpy(audio)
53
-
54
- if(orig_samples.shape[-1]<audio.shape[-1]):
55
- orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
56
- else:
57
- orig_samples = orig_samples[:,0:audio.shape[-1]]
58
- output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
59
- return output
 
1
+ import json
2
+ import torch
3
+ from tqdm import tqdm
4
+ import torchaudio
5
+ import librosa
6
+ import os
7
+ import math
8
+ import numpy as np
9
+ from get_melvaehifigan48k import build_pretrained_models
10
+ import tools.torch_tools as torch_tools
11
+
12
+ class Tango:
13
+ def __init__(self, \
14
+ device="cuda:0"):
15
+
16
+ self.sample_rate = 48000
17
+ self.device = device
18
+
19
+ self.vae, self.stft = build_pretrained_models()
20
+ self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
21
+
22
+ def mel_spectrogram_to_waveform(self, mel_spectrogram):
23
+ if mel_spectrogram.dim() == 4:
24
+ mel_spectrogram = mel_spectrogram.squeeze(1)
25
+
26
+ waveform = self.vocoder(mel_spectrogram)
27
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
28
+ waveform = waveform.cpu().float()
29
+ return waveform
30
+
31
+ def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
32
+ """ Genrate audio without condition. """
33
+ num_frames = math.ceil(duration * 100. / 8)
34
+ with torch.no_grad():
35
+ orig_samples, fs = torchaudio.load(fname)
36
+ if(orig_samples.shape[-1]<int(duration*48000*2)):
37
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*48000*2+480)-orig_samples.shape[-1], \
38
+ dtype=orig_samples.dtype, device=orig_samples.device)], -1)
39
+ # orig_samples = torch.cat([torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device), orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
40
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs)//2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
41
+ if(fs!=48000):orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
42
+ # resampled_audios = orig_samples[[0],int(4.64*48000):int(35.36*48000)+480].clamp(-1,1)
43
+ resampled_audios = orig_samples[[0],0:int(duration*2*48000)+480].clamp(-1,1)
44
+ orig_samples = orig_samples[[0],0:int(duration*2*48000)]
45
+
46
+ mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
47
+ mel = mel.unsqueeze(1).to(self.device)
48
+ latents = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0)
49
+
50
+ mel = self.vae.decode_first_stage(latents)
51
+ audio = self.vae.decode_to_waveform(mel)
52
+ audio = torch.from_numpy(audio)
53
+
54
+ if(orig_samples.shape[-1]<audio.shape[-1]):
55
+ orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1]-orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)],-1)
56
+ else:
57
+ orig_samples = orig_samples[:,0:audio.shape[-1]]
58
+ output = torch.cat([orig_samples.detach().cpu(),audio.detach().cpu()],0)
59
+ return output
codeclm/tokenizer/Flow1dVAE/tools/mix.py CHANGED
@@ -1,51 +1,51 @@
1
- import numpy as np
2
-
3
-
4
- def a_weight(fs, n_fft, min_db=-80.0):
5
- freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
6
- freq_sq = np.power(freq, 2)
7
- freq_sq[0] = 1.0
8
- weight = 2.0 + 20.0 * (2 * np.log10(12194) + 2 * np.log10(freq_sq)
9
- - np.log10(freq_sq + 12194 ** 2)
10
- - np.log10(freq_sq + 20.6 ** 2)
11
- - 0.5 * np.log10(freq_sq + 107.7 ** 2)
12
- - 0.5 * np.log10(freq_sq + 737.9 ** 2))
13
- weight = np.maximum(weight, min_db)
14
-
15
- return weight
16
-
17
-
18
- def compute_gain(sound, fs, min_db=-80.0, mode="A_weighting"):
19
- if fs == 16000:
20
- n_fft = 2048
21
- elif fs == 44100:
22
- n_fft = 4096
23
- else:
24
- raise Exception("Invalid fs {}".format(fs))
25
- stride = n_fft // 2
26
-
27
- gain = []
28
- for i in range(0, len(sound) - n_fft + 1, stride):
29
- if mode == "RMSE":
30
- g = np.mean(sound[i: i + n_fft] ** 2)
31
- elif mode == "A_weighting":
32
- spec = np.fft.rfft(np.hanning(n_fft + 1)[:-1] * sound[i: i + n_fft])
33
- power_spec = np.abs(spec) ** 2
34
- a_weighted_spec = power_spec * np.power(10, a_weight(fs, n_fft) / 10)
35
- g = np.sum(a_weighted_spec)
36
- else:
37
- raise Exception("Invalid mode {}".format(mode))
38
- gain.append(g)
39
-
40
- gain = np.array(gain)
41
- gain = np.maximum(gain, np.power(10, min_db / 10))
42
- gain_db = 10 * np.log10(gain)
43
- return gain_db
44
-
45
-
46
- def mix(sound1, sound2, r, fs):
47
- gain1 = np.max(compute_gain(sound1, fs)) # Decibel
48
- gain2 = np.max(compute_gain(sound2, fs))
49
- t = 1.0 / (1 + np.power(10, (gain1 - gain2) / 20.) * (1 - r) / r)
50
- sound = ((sound1 * t + sound2 * (1 - t)) / np.sqrt(t ** 2 + (1 - t) ** 2))
51
  return sound
 
1
+ import numpy as np
2
+
3
+
4
+ def a_weight(fs, n_fft, min_db=-80.0):
5
+ freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
6
+ freq_sq = np.power(freq, 2)
7
+ freq_sq[0] = 1.0
8
+ weight = 2.0 + 20.0 * (2 * np.log10(12194) + 2 * np.log10(freq_sq)
9
+ - np.log10(freq_sq + 12194 ** 2)
10
+ - np.log10(freq_sq + 20.6 ** 2)
11
+ - 0.5 * np.log10(freq_sq + 107.7 ** 2)
12
+ - 0.5 * np.log10(freq_sq + 737.9 ** 2))
13
+ weight = np.maximum(weight, min_db)
14
+
15
+ return weight
16
+
17
+
18
+ def compute_gain(sound, fs, min_db=-80.0, mode="A_weighting"):
19
+ if fs == 16000:
20
+ n_fft = 2048
21
+ elif fs == 44100:
22
+ n_fft = 4096
23
+ else:
24
+ raise Exception("Invalid fs {}".format(fs))
25
+ stride = n_fft // 2
26
+
27
+ gain = []
28
+ for i in range(0, len(sound) - n_fft + 1, stride):
29
+ if mode == "RMSE":
30
+ g = np.mean(sound[i: i + n_fft] ** 2)
31
+ elif mode == "A_weighting":
32
+ spec = np.fft.rfft(np.hanning(n_fft + 1)[:-1] * sound[i: i + n_fft])
33
+ power_spec = np.abs(spec) ** 2
34
+ a_weighted_spec = power_spec * np.power(10, a_weight(fs, n_fft) / 10)
35
+ g = np.sum(a_weighted_spec)
36
+ else:
37
+ raise Exception("Invalid mode {}".format(mode))
38
+ gain.append(g)
39
+
40
+ gain = np.array(gain)
41
+ gain = np.maximum(gain, np.power(10, min_db / 10))
42
+ gain_db = 10 * np.log10(gain)
43
+ return gain_db
44
+
45
+
46
+ def mix(sound1, sound2, r, fs):
47
+ gain1 = np.max(compute_gain(sound1, fs)) # Decibel
48
+ gain2 = np.max(compute_gain(sound2, fs))
49
+ t = 1.0 / (1 + np.power(10, (gain1 - gain2) / 20.) * (1 - r) / r)
50
+ sound = ((sound1 * t + sound2 * (1 - t)) / np.sqrt(t ** 2 + (1 - t) ** 2))
51
  return sound
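mix() balances the two mono signals by their A-weighted peak gain before summing, so r acts as a loudness-based mixing ratio rather than a plain amplitude weight; compute_gain only accepts fs of 16000 or 44100. A small usage sketch with synthetic signals (purely illustrative):

import numpy as np
from tools.mix import mix

fs = 16000
t = np.arange(fs) / fs                        # one second of audio
sound1 = 0.5 * np.sin(2 * np.pi * 440 * t)    # 440 Hz tone
sound2 = 0.1 * np.random.randn(fs)            # quieter noise floor
mixed = mix(sound1, sound2, r=0.5, fs=fs)     # r sets the balance between the two sources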
codeclm/tokenizer/Flow1dVAE/tools/torch_tools.py CHANGED
@@ -1,143 +1,143 @@
1
- import torch
2
- import torchaudio
3
- import random
4
- import itertools
5
- import numpy as np
6
- from tools.mix import mix
7
-
8
-
9
- def normalize_wav(waveform):
10
- waveform = waveform - torch.mean(waveform)
11
- waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
12
- return waveform * 0.5
13
-
14
-
15
- def pad_wav(waveform, segment_length):
16
- waveform_length = len(waveform)
17
-
18
- if segment_length is None or waveform_length == segment_length:
19
- return waveform
20
- elif waveform_length > segment_length:
21
- return waveform[:segment_length]
22
- else:
23
- pad_wav = torch.zeros(segment_length - waveform_length).to(waveform.device)
24
- waveform = torch.cat([waveform, pad_wav])
25
- return waveform
26
-
27
-
28
- def _pad_spec(fbank, target_length=1024):
29
- batch, n_frames, channels = fbank.shape
30
- p = target_length - n_frames
31
- if p > 0:
32
- pad = torch.zeros(batch, p, channels).to(fbank.device)
33
- fbank = torch.cat([fbank, pad], 1)
34
- elif p < 0:
35
- fbank = fbank[:, :target_length, :]
36
-
37
- if channels % 2 != 0:
38
- fbank = fbank[:, :, :-1]
39
-
40
- return fbank
41
-
42
-
43
- def read_wav_file(filename, segment_length):
44
- waveform, sr = torchaudio.load(filename) # Faster!!!
45
- waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)[0]
46
- try:
47
- waveform = normalize_wav(waveform)
48
- except:
49
- print ("Exception normalizing:", filename)
50
- waveform = torch.ones(160000)
51
- waveform = pad_wav(waveform, segment_length).unsqueeze(0)
52
- waveform = waveform / torch.max(torch.abs(waveform))
53
- waveform = 0.5 * waveform
54
- return waveform
55
-
56
-
57
- def get_mel_from_wav(audio, _stft):
58
- audio = torch.nan_to_num(torch.clip(audio, -1, 1))
59
- audio = torch.autograd.Variable(audio, requires_grad=False)
60
- melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
61
- return melspec, log_magnitudes_stft, energy
62
-
63
-
64
- def wav_to_fbank(paths, target_length=1024, fn_STFT=None):
65
- assert fn_STFT is not None
66
-
67
- waveform = torch.cat([read_wav_file(path, target_length * 160) for path in paths], 0) # hop size is 160
68
-
69
- fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
70
- fbank = fbank.transpose(1, 2)
71
- log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
72
-
73
- fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
74
- log_magnitudes_stft, target_length
75
- )
76
-
77
- return fbank, log_magnitudes_stft, waveform
78
-
79
- def wav_to_fbank2(waveform, target_length=-1, fn_STFT=None):
80
- assert fn_STFT is not None
81
-
82
- fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
83
- fbank = fbank.transpose(1, 2)
84
- log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
85
- # print(fbank.shape, log_magnitudes_stft.shape)
86
-
87
- if(target_length>0):
88
- fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
89
- log_magnitudes_stft, target_length
90
- )
91
-
92
- return fbank, log_magnitudes_stft, waveform
93
-
94
-
95
- def uncapitalize(s):
96
- if s:
97
- return s[:1].lower() + s[1:]
98
- else:
99
- return ""
100
-
101
-
102
- def mix_wavs_and_captions(path1, path2, caption1, caption2, target_length=1024):
103
- sound1 = read_wav_file(path1, target_length * 160)[0].numpy()
104
- sound2 = read_wav_file(path2, target_length * 160)[0].numpy()
105
- mixed_sound = mix(sound1, sound2, 0.5, 16000).reshape(1, -1)
106
- mixed_caption = "{} and {}".format(caption1, uncapitalize(caption2))
107
- return mixed_sound, mixed_caption
108
-
109
-
110
- def augment(paths, texts, num_items=4, target_length=1024):
111
- mixed_sounds, mixed_captions = [], []
112
- combinations = list(itertools.combinations(list(range(len(texts))), 2))
113
- random.shuffle(combinations)
114
- if len(combinations) < num_items:
115
- selected_combinations = combinations
116
- else:
117
- selected_combinations = combinations[:num_items]
118
-
119
- for (i, j) in selected_combinations:
120
- new_sound, new_caption = mix_wavs_and_captions(paths[i], paths[j], texts[i], texts[j], target_length)
121
- mixed_sounds.append(new_sound)
122
- mixed_captions.append(new_caption)
123
-
124
- waveform = torch.tensor(np.concatenate(mixed_sounds, 0))
125
- waveform = waveform / torch.max(torch.abs(waveform))
126
- waveform = 0.5 * waveform
127
-
128
- return waveform, mixed_captions
129
-
130
-
131
- def augment_wav_to_fbank(paths, texts, num_items=4, target_length=1024, fn_STFT=None):
132
- assert fn_STFT is not None
133
-
134
- waveform, captions = augment(paths, texts)
135
- fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
136
- fbank = fbank.transpose(1, 2)
137
- log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
138
-
139
- fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
140
- log_magnitudes_stft, target_length
141
- )
142
-
143
  return fbank, log_magnitudes_stft, waveform, captions
 
1
+ import torch
2
+ import torchaudio
3
+ import random
4
+ import itertools
5
+ import numpy as np
6
+ from tools.mix import mix
7
+
8
+
9
+ def normalize_wav(waveform):
10
+ waveform = waveform - torch.mean(waveform)
11
+ waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
12
+ return waveform * 0.5
13
+
14
+
15
+ def pad_wav(waveform, segment_length):
16
+ waveform_length = len(waveform)
17
+
18
+ if segment_length is None or waveform_length == segment_length:
19
+ return waveform
20
+ elif waveform_length > segment_length:
21
+ return waveform[:segment_length]
22
+ else:
23
+ pad_wav = torch.zeros(segment_length - waveform_length).to(waveform.device)
24
+ waveform = torch.cat([waveform, pad_wav])
25
+ return waveform
26
+
27
+
28
+ def _pad_spec(fbank, target_length=1024):
29
+ batch, n_frames, channels = fbank.shape
30
+ p = target_length - n_frames
31
+ if p > 0:
32
+ pad = torch.zeros(batch, p, channels).to(fbank.device)
33
+ fbank = torch.cat([fbank, pad], 1)
34
+ elif p < 0:
35
+ fbank = fbank[:, :target_length, :]
36
+
37
+ if channels % 2 != 0:
38
+ fbank = fbank[:, :, :-1]
39
+
40
+ return fbank
41
+
42
+
43
+ def read_wav_file(filename, segment_length):
44
+ waveform, sr = torchaudio.load(filename) # Faster!!!
45
+ waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)[0]
46
+ try:
47
+ waveform = normalize_wav(waveform)
48
+ except:
49
+ print ("Exception normalizing:", filename)
50
+ waveform = torch.ones(160000)
51
+ waveform = pad_wav(waveform, segment_length).unsqueeze(0)
52
+ waveform = waveform / torch.max(torch.abs(waveform))
53
+ waveform = 0.5 * waveform
54
+ return waveform
55
+
56
+
57
+ def get_mel_from_wav(audio, _stft):
58
+ audio = torch.nan_to_num(torch.clip(audio, -1, 1))
59
+ audio = torch.autograd.Variable(audio, requires_grad=False)
60
+ melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
61
+ return melspec, log_magnitudes_stft, energy
62
+
63
+
64
+ def wav_to_fbank(paths, target_length=1024, fn_STFT=None):
65
+ assert fn_STFT is not None
66
+
67
+ waveform = torch.cat([read_wav_file(path, target_length * 160) for path in paths], 0) # hop size is 160
68
+
69
+ fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
70
+ fbank = fbank.transpose(1, 2)
71
+ log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
72
+
73
+ fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
74
+ log_magnitudes_stft, target_length
75
+ )
76
+
77
+ return fbank, log_magnitudes_stft, waveform
78
+
79
+ def wav_to_fbank2(waveform, target_length=-1, fn_STFT=None):
80
+ assert fn_STFT is not None
81
+
82
+ fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
83
+ fbank = fbank.transpose(1, 2)
84
+ log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
85
+ # print(fbank.shape, log_magnitudes_stft.shape)
86
+
87
+ if(target_length>0):
88
+ fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
89
+ log_magnitudes_stft, target_length
90
+ )
91
+
92
+ return fbank, log_magnitudes_stft, waveform
93
+
94
+
95
+ def uncapitalize(s):
96
+ if s:
97
+ return s[:1].lower() + s[1:]
98
+ else:
99
+ return ""
100
+
101
+
102
+ def mix_wavs_and_captions(path1, path2, caption1, caption2, target_length=1024):
103
+ sound1 = read_wav_file(path1, target_length * 160)[0].numpy()
104
+ sound2 = read_wav_file(path2, target_length * 160)[0].numpy()
105
+ mixed_sound = mix(sound1, sound2, 0.5, 16000).reshape(1, -1)
106
+ mixed_caption = "{} and {}".format(caption1, uncapitalize(caption2))
107
+ return mixed_sound, mixed_caption
108
+
109
+
110
+ def augment(paths, texts, num_items=4, target_length=1024):
111
+ mixed_sounds, mixed_captions = [], []
112
+ combinations = list(itertools.combinations(list(range(len(texts))), 2))
113
+ random.shuffle(combinations)
114
+ if len(combinations) < num_items:
115
+ selected_combinations = combinations
116
+ else:
117
+ selected_combinations = combinations[:num_items]
118
+
119
+ for (i, j) in selected_combinations:
120
+ new_sound, new_caption = mix_wavs_and_captions(paths[i], paths[j], texts[i], texts[j], target_length)
121
+ mixed_sounds.append(new_sound)
122
+ mixed_captions.append(new_caption)
123
+
124
+ waveform = torch.tensor(np.concatenate(mixed_sounds, 0))
125
+ waveform = waveform / torch.max(torch.abs(waveform))
126
+ waveform = 0.5 * waveform
127
+
128
+ return waveform, mixed_captions
129
+
130
+
131
+ def augment_wav_to_fbank(paths, texts, num_items=4, target_length=1024, fn_STFT=None):
132
+ assert fn_STFT is not None
133
+
134
+ waveform, captions = augment(paths, texts)
135
+ fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
136
+ fbank = fbank.transpose(1, 2)
137
+ log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
138
+
139
+ fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
140
+ log_magnitudes_stft, target_length
141
+ )
142
+
143
  return fbank, log_magnitudes_stft, waveform, captions
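wav_to_fbank2 is the entry point used by the inference scripts above: it turns an already-loaded waveform into a mel filterbank using the STFT module returned by build_pretrained_models. A condensed sketch of that call chain (the input path is a placeholder):

import torchaudio
import tools.torch_tools as torch_tools
from get_melvaehifigan48k import build_pretrained_models  # as in the infer_* scripts above

vae, stft = build_pretrained_models()
wav, fs = torchaudio.load("example.wav")
if fs != 48000:
    wav = torchaudio.functional.resample(wav, fs, 48000)
fbank, log_mag, _ = torch_tools.wav_to_fbank2(wav[:1].clamp(-1, 1), -1, fn_STFT=stft)
mel = fbank.unsqueeze(1)   # [B, 1, T, n_mels], the layout expected by vae.encode_first_stage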
codeclm/tokenizer/audio_tokenizer.py CHANGED
@@ -208,9 +208,9 @@ class Flow1dVAESeparate(AudioTokenizer):
208
  return codes_vocal, codes_bgm
209
 
210
  @torch.no_grad()
211
- def decode(self, codes: torch.Tensor, prompt_vocal = None, prompt_bgm = None):
212
  wav = self.model.code2sound(codes, prompt_vocal=prompt_vocal, prompt_bgm=prompt_bgm, guidance_scale=1.5,
213
- num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
214
  return wav[None]
215
 
216
 
 
208
  return codes_vocal, codes_bgm
209
 
210
  @torch.no_grad()
211
+ def decode(self, codes: torch.Tensor, prompt_vocal = None, prompt_bgm = None, chunked=False):
212
  wav = self.model.code2sound(codes, prompt_vocal=prompt_vocal, prompt_bgm=prompt_bgm, guidance_scale=1.5,
213
+ num_steps=50, disable_progress=False, chunked=chunked) # [B,N,T] -> [B,T]
214
  return wav[None]
215
 
216
 
generate_lowmem.py ADDED
@@ -0,0 +1,240 @@
1
+ import sys
2
+ import os
3
+
4
+ import time
5
+ import json
6
+ import torch
7
+ import torchaudio
8
+ import numpy as np
9
+ from omegaconf import OmegaConf
10
+ from codeclm.models import builders
11
+
12
+ from codeclm.trainer.codec_song_pl import CodecLM_PL
13
+ from codeclm.models import CodecLM
14
+ from third_party.demucs.models.pretrained import get_model_from_yaml
15
+
16
+ auto_prompt_type = ['Pop', 'R&B', 'Dance', 'Jazz', 'Folk', 'Rock', 'Chinese Style', 'Chinese Tradition', 'Metal', 'Reggae', 'Chinese Opera', 'Auto']
17
+
18
+ class Separator:
19
+ def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
20
+ if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
21
+ self.device = torch.device(f"cuda:{gpu_id}")
22
+ else:
23
+ self.device = torch.device("cpu")
24
+ self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
25
+
26
+ def init_demucs_model(self, model_path, config_path):
27
+ model = get_model_from_yaml(config_path, model_path)
28
+ model.to(self.device)
29
+ model.eval()
30
+ return model
31
+
32
+ def load_audio(self, f):
33
+ a, fs = torchaudio.load(f)
34
+ if (fs != 48000):
35
+ a = torchaudio.functional.resample(a, fs, 48000)
36
+ if a.shape[-1] >= 48000*10:
37
+ a = a[..., :48000*10]
38
+ else:
39
+ a = torch.cat([a, a], -1)
40
+ return a[:, 0:48000*10]
41
+
42
+ def run(self, audio_path, output_dir='tmp', ext=".flac"):
43
+ os.makedirs(output_dir, exist_ok=True)
44
+ name, _ = os.path.splitext(os.path.split(audio_path)[-1])
45
+ output_paths = []
46
+
47
+ for stem in self.demucs_model.sources:
48
+ output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
49
+ if os.path.exists(output_path):
50
+ output_paths.append(output_path)
51
+ if len(output_paths) == 1:  # only the vocal stem survives earlier runs; otherwise separate all 4 Demucs stems
52
+ vocal_path = output_paths[0]
53
+ else:
54
+ drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
55
+ for path in [drums_path, bass_path, other_path]:
56
+ os.remove(path)
57
+ full_audio = self.load_audio(audio_path)
58
+ vocal_audio = self.load_audio(vocal_path)
59
+ bgm_audio = full_audio - vocal_audio
60
+ return full_audio, vocal_audio, bgm_audio
61
+
62
+
63
+
64
+ if __name__ == "__main__":
65
+ torch.backends.cudnn.enabled = False
66
+ OmegaConf.register_new_resolver("eval", lambda x: eval(x))
67
+ OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
68
+ OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
69
+ OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
70
+ np.random.seed(int(time.time()))
71
+ ckpt_path = sys.argv[1]
72
+ input_jsonl = sys.argv[2]
73
+ save_dir = sys.argv[3]
74
+ cfg_path = os.path.join(ckpt_path, 'config.yaml')
75
+ ckpt_path = os.path.join(ckpt_path, 'model.pt')
76
+ cfg = OmegaConf.load(cfg_path)
77
+ cfg.mode = 'inference'
78
+ max_duration = cfg.max_dur
79
+
80
+ separator = Separator()
81
+ auto_prompt = torch.load('ckpt/prompt.pt')
82
+ audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
83
+ if "audio_tokenizer_checkpoint_sep" in cfg.keys():
84
+ seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
85
+ else:
86
+ seperate_tokenizer = None
87
+ audio_tokenizer = audio_tokenizer.eval().cuda()
88
+ if seperate_tokenizer is not None:
89
+ seperate_tokenizer = seperate_tokenizer.eval().cuda()
90
+
91
+ merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
92
+ with open(input_jsonl, "r") as fp:
93
+ lines = fp.readlines()
94
+ new_items = []
95
+ for line in lines:
96
+ item = json.loads(line)
97
+ target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
98
+ # get prompt audio
99
+ if "prompt_audio_path" in item:
100
+ assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
101
+ assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
102
+ pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
103
+ item['raw_pmt_wav'] = pmt_wav
104
+ item['raw_vocal_wav'] = vocal_wav
105
+ item['raw_bgm_wav'] = bgm_wav
106
+ if pmt_wav.dim() == 2:
107
+ pmt_wav = pmt_wav[None]
108
+ if pmt_wav.dim() != 3:
109
+ raise ValueError("Melody wavs should have a shape [B, C, T].")
110
+ pmt_wav = list(pmt_wav)
111
+ if vocal_wav.dim() == 2:
112
+ vocal_wav = vocal_wav[None]
113
+ if vocal_wav.dim() != 3:
114
+ raise ValueError("Vocal wavs should have a shape [B, C, T].")
115
+ vocal_wav = list(vocal_wav)
116
+ if bgm_wav.dim() == 2:
117
+ bgm_wav = bgm_wav[None]
118
+ if bgm_wav.dim() != 3:
119
+ raise ValueError("BGM wavs should have a shape [B, C, T].")
120
+ bgm_wav = list(bgm_wav)
121
+ if type(pmt_wav) == list:
122
+ pmt_wav = torch.stack(pmt_wav, dim=0)
123
+ if type(vocal_wav) == list:
124
+ vocal_wav = torch.stack(vocal_wav, dim=0)
125
+ if type(bgm_wav) == list:
126
+ bgm_wav = torch.stack(bgm_wav, dim=0)
127
+ pmt_wav = pmt_wav.cuda()
128
+ vocal_wav = vocal_wav.cuda()
129
+ bgm_wav = bgm_wav.cuda()
130
+ pmt_wav, _ = audio_tokenizer.encode(pmt_wav)
131
+ vocal_wav, bgm_wav = seperate_tokenizer.encode(vocal_wav, bgm_wav)
132
+ melody_is_wav = False
133
+ elif "auto_prompt_audio_type" in item:
134
+ assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
135
+ if item["auto_prompt_audio_type"] == "Auto":
136
+ prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
137
+ else:
138
+ prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
139
+ pmt_wav = prompt_token[:,[0],:]
140
+ vocal_wav = prompt_token[:,[1],:]
141
+ bgm_wav = prompt_token[:,[2],:]
142
+ melody_is_wav = False
143
+ else:
144
+ pmt_wav = None
145
+ vocal_wav = None
146
+ bgm_wav = None
147
+ melody_is_wav = True
148
+ item['pmt_wav'] = pmt_wav
149
+ item['vocal_wav'] = vocal_wav
150
+ item['bgm_wav'] = bgm_wav
151
+ item['melody_is_wav'] = melody_is_wav
152
+ item["idx"] = f"{item['idx']}"
153
+ item["wav_path"] = target_wav_name
154
+ new_items.append(item)
155
+
156
+ del audio_tokenizer
157
+ del seperate_tokenizer
158
+ del separator
159
+
160
+ # Define model or load pretrained model
161
+ model_light = CodecLM_PL(cfg, ckpt_path)
162
+ model_light = model_light.eval()
163
+ model_light.audiolm.cfg = cfg
164
+ model = CodecLM(name = "tmp",
165
+ lm = model_light.audiolm,
166
+ audiotokenizer = None,
167
+ max_duration = max_duration,
168
+ seperate_tokenizer = None,
169
+ )
170
+ del model_light
171
+ model.lm = model.lm.cuda().to(torch.float16)
172
+
173
+ cfg_coef = 1.5 #25
174
+ temp = 0.9
175
+ top_k = 50
176
+ top_p = 0.0
177
+ record_tokens = True
178
+ record_window = 50
179
+
180
+ model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
181
+ top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
182
+ os.makedirs(save_dir, exist_ok=True)
183
+ os.makedirs(save_dir + "/audios", exist_ok=True)
184
+ os.makedirs(save_dir + "/jsonl", exist_ok=True)
185
+
186
+
187
+ for item in new_items:
188
+ lyric = item["gt_lyric"]
189
+ descriptions = item["descriptions"] if "descriptions" in item else None
190
+ pmt_wav = item['pmt_wav']
191
+ vocal_wav = item['vocal_wav']
192
+ bgm_wav = item['bgm_wav']
193
+ melody_is_wav = item['melody_is_wav']
194
+
195
+ generate_inp = {
196
+ 'lyrics': [lyric.replace(" ", " ")],
197
+ 'descriptions': [descriptions],
198
+ 'melody_wavs': pmt_wav,
199
+ 'vocal_wavs': vocal_wav,
200
+ 'bgm_wavs': bgm_wav,
201
+ 'melody_is_wav': melody_is_wav,
202
+ }
203
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
204
+ tokens = model.generate(**generate_inp, return_tokens=True)
205
+ item['tokens'] = tokens
206
+
207
+ del model
208
+ torch.cuda.empty_cache()
209
+
210
+
211
+ seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
212
+ seperate_tokenizer = seperate_tokenizer.eval().cuda()
213
+
214
+ model = CodecLM(name = "tmp",
215
+ lm = None,
216
+ audiotokenizer = None,
217
+ max_duration = max_duration,
218
+ seperate_tokenizer = seperate_tokenizer,
219
+ )
220
+ for item in new_items:
221
+ with torch.no_grad():
222
+ if 'raw_pmt_wav' in item:
223
+ wav_seperate = model.generate_audio(item['tokens'], item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True)
224
+ del item['raw_pmt_wav']
225
+ del item['raw_vocal_wav']
226
+ del item['raw_bgm_wav']
227
+ else:
228
+ wav_seperate = model.generate_audio(item['tokens'], chunked=True)
229
+ torchaudio.save(item['wav_path'], wav_seperate[0].cpu().float(), cfg.sample_rate)
230
+ del item['tokens']
231
+ del item['pmt_wav']
232
+ del item['vocal_wav']
233
+ del item['bgm_wav']
234
+ del item['melody_is_wav']
235
+
236
+ torch.cuda.empty_cache()
237
+ src_jsonl_name = os.path.split(input_jsonl)[-1]
238
+ with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
239
+ for item in new_items:
240
+ fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
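generate_lowmem.py expects one JSON object per line, each with an idx and a gt_lyric, plus at most one prompt source (prompt_audio_path, auto_prompt_audio_type, or descriptions). A sketch of writing such an input file (the lyric text and file names are placeholders):

import json

record = {
    "idx": "demo_song",
    "gt_lyric": "[intro-short] ; [verse] first line.second line ; [chorus] hook line ; [outro-short]",
    "auto_prompt_audio_type": "Pop",   # alternatively: "prompt_audio_path": "..." or "descriptions": "..."
}
with open("demo.jsonl", "w", encoding="utf-8") as fw:
    fw.write(json.dumps(record, ensure_ascii=False) + "\n")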
generate_lowmem.sh ADDED
@@ -0,0 +1,10 @@
1
+ export USER=root
2
+ export PYTHONDONTWRITEBYTECODE=1
3
+ export TRANSFORMERS_CACHE="$(pwd)/third_party/hub"
4
+ export NCCL_HOME=/usr/local/tccl
5
+ export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH
6
+
7
+ CKPT_PATH=$1
8
+ JSONL=$2
9
+ SAVE_DIR=$3
10
+ python3 generate_lowmem.py $CKPT_PATH $JSONL $SAVE_DIR
requirements.txt CHANGED
@@ -0,0 +1,24 @@
1
+ alias-free-torch>=0.0.6
2
+ descript-audio-codec>=1.0.0
3
+ diffusers==0.27.2
4
+ einops>=0.8.1
5
+ einops-exts==0.0.4
6
+ flashy>=0.0.2
7
+ huggingface-hub==0.25.2
8
+ julius>=0.2.7
9
+ k-diffusion==0.1.1
10
+ kaldiio>=2.18.1
11
+ lameenc>=1.8.1
12
+ librosa>=0.11.0
13
+ lightning>=2.5.2
14
+ ninja>=1.11.1.4
15
+ nnAudio>=0.3.3
16
+ openunmix>=1.3.0
17
+ peft==0.10.0
18
+ torch==2.6.0
19
+ torchaudio==2.6.0
20
+ torchvision==0.21.0
21
+ transformers==4.37.2
22
+ vector-quantize-pytorch>=1.22.17
23
+ wheel>=0.45.1
24
+ x-transformers>=2.3.25
requirements_nodeps.txt ADDED
@@ -0,0 +1,13 @@
1
+ fairseq==0.12.2
2
+ antlr4-python3-runtime==4.8
3
+ bitarray==3.4.3
4
+ cffi==1.17.1
5
+ colorama==0.4.6
6
+ cython==3.1.2
7
+ hydra-core==1.0.7
8
+ lxml==5.4.0
9
+ omegaconf==2.2.0
10
+ portalocker==3.2.0
11
+ pycparser==2.22
12
+ sacrebleu==2.5.1
13
+ tabulate==0.9.0
sample/lyrics.jsonl CHANGED
@@ -1,4 +1,4 @@
1
  {"idx": "sample_01_autoprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "auto_prompt_audio_type": "Auto"}
2
  {"idx": "sample_01_noprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
3
  {"idx": "sample_01_textprompt", "descriptions": "female, dark, pop, sad, piano and drums, the bpm is 125.", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
4
- {"idx": "sample_01_audioprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "prompt_audio_path": "sample/sample_prompt_audio.wav"}
 
1
  {"idx": "sample_01_autoprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "auto_prompt_audio_type": "Auto"}
2
  {"idx": "sample_01_noprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
3
  {"idx": "sample_01_textprompt", "descriptions": "female, dark, pop, sad, piano and drums, the bpm is 125.", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
4
+ {"idx": "sample_01_audioprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "prompt_audio_path": "input/sample_prompt_audio.wav"}
tools/gradio/app.py ADDED
@@ -0,0 +1,236 @@
1
+ import sys
2
+ import gradio as gr
3
+ import json
4
+ from datetime import datetime
5
+ import yaml
6
+ import time
7
+ import re
8
+ import os.path as op
9
+ from levo_inference_lowmem import LeVoInference
10
+
11
+ EXAMPLE_LYRICS = """
12
+ [intro-short]
13
+
14
+ [verse]
15
+ 夜晚的街灯闪烁
16
+ 我漫步在熟悉的角落
17
+ 回忆像潮水般涌来
18
+ 你的笑容如此清晰
19
+ 在心头无法抹去
20
+ 那些曾经的甜蜜
21
+ 如今只剩我独自回忆
22
+
23
+ [verse]
24
+ 手机屏幕亮起
25
+ 是你发来的消息
26
+ 简单的几个字
27
+ 却让我泪流满面
28
+ 曾经的拥抱温暖
29
+ 如今却变得遥远
30
+ 我多想回到从前
31
+ 重新拥有你的陪伴
32
+
33
+ [chorus]
34
+ 回忆的温度还在
35
+ 你却已不在
36
+ 我的心被爱填满
37
+ 却又被思念刺痛
38
+ 音乐的节奏奏响
39
+ 我的心却在流浪
40
+ 没有你的日子
41
+ 我该如何继续向前
42
+
43
+ [outro-short]
44
+ """.strip()
45
+
46
+ APP_DIR = op.dirname(op.dirname(op.dirname(op.abspath(__file__))))
47
+ MODEL = LeVoInference(sys.argv[1])
48
+ with open(op.join(APP_DIR, 'conf/vocab.yaml'), 'r', encoding='utf-8') as file:
49
+ STRUCTS = yaml.safe_load(file)
50
+
51
+
52
+ def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_coef=None, temperature=None, top_k=None, progress=gr.Progress(track_tqdm=True)):
53
+ global MODEL
54
+ global STRUCTS
55
+ params = {'cfg_coef':cfg_coef, 'temperature':temperature, 'top_k':top_k}
56
+ params = {k:v for k,v in params.items() if v is not None}
57
+ vocal_structs = ['[verse]', '[chorus]', '[bridge]']
58
+ sample_rate = MODEL.cfg.sample_rate
59
+
60
+ # format lyric
61
+ lyric = lyric.replace("[intro]", "[intro-short]").replace("[inst]", "[inst-short]").replace("[outro]", "[outro-short]")
62
+ paragraphs = [p.strip() for p in lyric.strip().split('\n\n') if p.strip()]
63
+ if len(paragraphs) < 1:
64
+ return None, json.dumps("Lyrics can not be left blank")
65
+ paragraphs_norm = []
66
+ vocal_flag = False
67
+ for para in paragraphs:
68
+ lines = para.splitlines()
69
+ struct_tag = lines[0].strip().lower()
70
+ if struct_tag not in STRUCTS:
71
+ return None, json.dumps(f"Segments should start with a structure tag in {STRUCTS}")
72
+ if struct_tag in vocal_structs:
73
+ vocal_flag = True
74
+ if len(lines) < 2 or not [line.strip() for line in lines[1:] if line.strip()]:
75
+ return None, json.dumps("The following segments require lyrics: [verse], [chorus], [bridge]")
76
+ else:
77
+ new_para_list = []
78
+ for line in lines[1:]:
79
+ new_para_list.append(re.sub(r"[^\w\s\[\]\-\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\uac00-\ud7af\u00c0-\u017f]", "", line))
80
+ new_para_str = f"{struct_tag} {'.'.join(new_para_list)}"
81
+ else:
82
+ if len(lines) > 1:
83
+ return None, json.dumps("The following segments should not contain lyrics: [intro], [intro-short], [intro-medium], [inst], [inst-short], [inst-medium], [outro], [outro-short], [outro-medium]")
84
+ else:
85
+ new_para_str = struct_tag
86
+ paragraphs_norm.append(new_para_str)
87
+ if not vocal_flag:
88
+ return None, json.dumps(f"The lyric must contain at least one of the following structures: {vocal_structs}")
89
+ lyric_norm = " ; ".join(paragraphs_norm)
90
+
91
+ # format prompt
92
+ if prompt_audio is not None:
93
+ genre = None
94
+ description = None
95
+ elif description is not None and description != "":
96
+ genre = None
97
+
98
+ progress(0.0, "Start Generation")
99
+ start = time.time()
100
+
101
+ audio_data = MODEL(lyric_norm, description, prompt_audio, genre, op.join(APP_DIR, "ckpt/prompt.pt"), params).cpu().permute(1, 0).float().numpy()
102
+
103
+ end = time.time()
104
+
105
+ # Build the JSON record of the input configuration
106
+ input_config = {
107
+ "lyric": lyric_norm,
108
+ "genre": genre,
109
+ "prompt_audio": prompt_audio,
110
+ "description": description,
111
+ "params": params,
112
+ "inference_duration": end - start,
113
+ "timestamp": datetime.now().isoformat(),
114
+ }
115
+
116
+ return (sample_rate, audio_data), json.dumps(input_config, indent=2)
117
+
118
+
119
+ # Build the Gradio interface
120
+ with gr.Blocks(title="SongGeneration Demo Space") as demo:
121
+ gr.Markdown("# 🎵 SongGeneration Demo Space")
122
+ gr.Markdown("Demo interface for the song generation model. Provide a lyrics, and optionally an audio or text prompt, to generate a custom song.")
123
+
124
+ with gr.Row():
125
+ with gr.Column():
126
+ lyric = gr.Textbox(
127
+ label="Lyrics",
128
+ lines=5,
129
+ max_lines=15,
130
+ value=EXAMPLE_LYRICS,
131
+ info="Each paragraph represents a segment starting with a structure tag and ending with a blank line, each line is a sentence without punctuation, segments [intro], [inst], [outro] should not contain lyrics, while [verse], [chorus], and [bridge] require lyrics.",
132
+ placeholder="""Lyric Format
133
+ '''
134
+ [structure tag]
135
+ lyrics
136
+
137
+ [structure tag]
138
+ lyrics
139
+ '''
140
+ 1. One paragraph represents one segment, starting with a structure tag and ending with a blank line
141
+ 2. One line represents one sentence, punctuation is not recommended inside the sentence
142
+ 3. The following segments should not contain lyrics: [intro-short], [intro-medium], [inst-short], [inst-medium], [outro-short], [outro-medium]
143
+ 4. The following segments require lyrics: [verse], [chorus], [bridge]
144
+ """
145
+ )
146
+
147
+ with gr.Tabs(elem_id="extra-tabs"):
148
+ with gr.Tab("Genre Select"):
149
+ genre = gr.Radio(
150
+ choices=["Pop", "R&B", "Dance", "Jazz", "Folk", "Rock", "Chinese Style", "Chinese Tradition", "Metal", "Reggae", "Chinese Opera", "Auto"],
151
+ label="Genre Select(Optional)",
152
+ value="Pop",
153
+ interactive=True,
154
+ elem_id="single-select-radio"
155
+ )
156
+ with gr.Tab("Audio Prompt"):
157
+ prompt_audio = gr.Audio(
158
+ label="Prompt Audio (Optional)",
159
+ type="filepath",
160
+ elem_id="audio-prompt"
161
+ )
162
+ with gr.Tab("Text Prompt"):
163
+ gr.Markdown("For detailed usage, please refer to [here](https://github.com/tencent-ailab/SongGeneration?tab=readme-ov-file#-description-input-format)")
164
+ description = gr.Textbox(
165
+ label="Song Description (Optional)",
166
+ info="Describe the gender, timbre, genre, emotion, instrument and bpm of the song. Only English is supported currently.​",
167
+ placeholder="female, dark, pop, sad, piano and drums, the bpm is 125.",
168
+ lines=1,
169
+ max_lines=2
170
+ )
171
+
172
+ with gr.Accordion("Advanced Config", open=False):
173
+ cfg_coef = gr.Slider(
174
+ label="CFG Coefficient",
175
+ minimum=0.1,
176
+ maximum=3.0,
177
+ step=0.1,
178
+ value=1.5,
179
+ interactive=True,
180
+ elem_id="cfg-coef",
181
+ )
182
+ temperature = gr.Slider(
183
+ label="Temperature",
184
+ minimum=0.1,
185
+ maximum=2.0,
186
+ step=0.1,
187
+ value=0.9,
188
+ interactive=True,
189
+ elem_id="temperature",
190
+ )
191
+ top_k = gr.Slider(
192
+ label="Top-K",
193
+ minimum=1,
194
+ maximum=100,
195
+ step=1,
196
+ value=50,
197
+ interactive=True,
198
+ elem_id="top_k",
199
+ )
200
+ generate_btn = gr.Button("Generate Song", variant="primary")
201
+
202
+ with gr.Column():
203
+ output_audio = gr.Audio(label="Generated Song", type="numpy")
204
+ output_json = gr.JSON(label="Generated Info")
205
+
206
+ # # Example buttons
207
+ # examples = gr.Examples(
208
+ # examples=[
209
+ # ["male, bright, rock, happy, electric guitar and drums, the bpm is 150."],
210
+ # ["female, warm, jazz, romantic, synthesizer and piano, the bpm is 100."]
211
+ # ],
212
+ # inputs=[description],
213
+ # label="Text Prompt examples"
214
+ # )
215
+
216
+ # examples = gr.Examples(
217
+ # examples=[
218
+ # "[intro-medium]\n\n[verse]\n在这个疯狂的世界里\n谁不渴望一点改变\n在爱情面前\n我们都显得那么不安全\n你紧紧抱着我\n告诉我再靠近一点\n别让这璀璨的夜晚白白浪费\n我那迷茫的眼睛\n看不见未来的路\n在情感消散之前\n我们对爱的渴望永不熄灭\n你给我留下一句誓言\n想知道我们的爱是否能持续到永远\n[chorus]\n\n约定在那最后的夜晚\n不管命运如何摆布\n我们的心是否依然如初\n我会穿上红衬衫\n带着摇滚的激情\n回到我们初遇的地方\n约定在那最后的夜晚\n就算全世界都变了样\n我依然坚守诺言\n铭记这一天\n你永远是我心中的爱恋\n\n[outro-medium]\n",
219
+ # "[intro-short]\n\n[verse]\nThrough emerald canyons where fireflies dwell\nCerulean berries kiss morning's first swell\nCrystalline dew crowns each Vitamin Dawn's confection dissolves slowly on me\nAmbrosia breezes through honeycomb vines\nNature's own candy in Fibonacci lines\n[chorus] Blueberry fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n You're under its spell\n [verse] Resin of sunlight in candied retreat\nMarmalade moonbeams melt under bare feet\nNectar spirals bloom chloroplast champagne\nPhotosynthesis sings through my veins\nChlorophyll rhythms pulse warm in my blood\nThe forest's green pharmacy floods every bud[chorus] Blueberry fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n You're under its spell\n feel the buzz\n ride the wave\n Limey me\n blueberry\n your mind's enslaved\n In the haze\n lose all time\n floating free\n feeling fine\n Blueberry\n fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n cry\n You're under its spell\n\n[outro-short]\n",
220
+ # ],
221
+ # inputs=[lyric],
222
+ # label="Lyrics examples",
223
+ # )
224
+
225
+ # Generate button click handler
226
+ generate_btn.click(
227
+ fn=generate_song,
228
+ inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, top_k],
229
+ outputs=[output_audio, output_json]
230
+ )
231
+
232
+
233
+ # Launch the app
234
+ if __name__ == "__main__":
235
+ demo.launch(server_name="0.0.0.0", server_port=8081)
236
+
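For reference, generate_song collapses the paragraph-per-segment lyric from the textbox into the single-line format the model consumes: each paragraph becomes "<tag> line1.line2" and paragraphs are joined with " ; ". A made-up illustration:

# Textbox input (two segments):
#   [intro-short]
#
#   [verse]
#   line one
#   line two
#
# Normalised string passed to the model:
lyric_norm = "[intro-short] ; [verse] line one.line two"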
tools/gradio/levo_inference.py ADDED
@@ -0,0 +1,110 @@
1
+ import os
2
+ import sys
3
+
4
+ import torch
5
+
6
+ import json
7
+ import numpy as np
8
+ from omegaconf import OmegaConf
9
+
10
+ from codeclm.trainer.codec_song_pl import CodecLM_PL
11
+ from codeclm.models import CodecLM
12
+
13
+ from separator import Separator
14
+
15
+
16
+ class LeVoInference(torch.nn.Module):
17
+ def __init__(self, ckpt_path):
18
+ super().__init__()
19
+
20
+ torch.backends.cudnn.enabled = False
21
+ OmegaConf.register_new_resolver("eval", lambda x: eval(x))
22
+ OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
23
+ OmegaConf.register_new_resolver("get_fname", lambda: 'default')
24
+ OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
25
+
26
+ cfg_path = os.path.join(ckpt_path, 'config.yaml')
27
+ pt_path = os.path.join(ckpt_path, 'model.pt')
28
+
29
+ self.cfg = OmegaConf.load(cfg_path)
30
+ self.cfg.mode = 'inference'
31
+ self.max_duration = self.cfg.max_dur
32
+
33
+ # Define model or load pretrained model
34
+ model_light = CodecLM_PL(self.cfg, pt_path)
35
+
36
+ model_light = model_light.eval().cuda()
37
+ model_light.audiolm.cfg = self.cfg
38
+
39
+ self.model_lm = model_light.audiolm
40
+ self.model_audio_tokenizer = model_light.audio_tokenizer
41
+ self.model_seperate_tokenizer = model_light.seperate_tokenizer
42
+
43
+ self.model = CodecLM(name = "tmp",
44
+ lm = self.model_lm,
45
+ audiotokenizer = self.model_audio_tokenizer,
46
+ max_duration = self.max_duration,
47
+ seperate_tokenizer = self.model_seperate_tokenizer,
48
+ )
49
+ self.separator = Separator()
50
+
51
+
52
+ self.default_params = dict(
53
+ cfg_coef = 1.5,
54
+ temperature = 1.0,
55
+ top_k = 50,
56
+ top_p = 0.0,
57
+ record_tokens = True,
58
+ record_window = 50,
59
+ extend_stride = 5,
60
+ duration = self.max_duration,
61
+ )
62
+
63
+ self.model.set_generation_params(**self.default_params)
64
+
65
+ def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, params = dict()):
66
+ params = {**self.default_params, **params}
67
+ self.model.set_generation_params(**params)
68
+
69
+ if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
70
+ pmt_wav, vocal_wav, bgm_wav = self.separator.run(prompt_audio_path)
71
+ melody_is_wav = True
72
+ elif genre is not None and auto_prompt_path is not None:
73
+ auto_prompt = torch.load(auto_prompt_path)
74
+ merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
75
+ if genre == "Auto":
76
+ prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
77
+ else:
78
+ prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
79
+ pmt_wav = prompt_token[:,[0],:]
80
+ vocal_wav = prompt_token[:,[1],:]
81
+ bgm_wav = prompt_token[:,[2],:]
82
+ melody_is_wav = False
83
+ else:
84
+ pmt_wav = None
85
+ vocal_wav = None
86
+ bgm_wav = None
87
+ melody_is_wav = True
88
+
89
+ generate_inp = {
90
+ 'lyrics': [lyric.replace(" ", " ")],
91
+ 'descriptions': [description],
92
+ 'melody_wavs': pmt_wav,
93
+ 'vocal_wavs': vocal_wav,
94
+ 'bgm_wavs': bgm_wav,
95
+ 'melody_is_wav': melody_is_wav,
96
+ }
97
+
98
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
99
+ tokens = self.model.generate(**generate_inp, return_tokens=True)
100
+
101
+ if tokens.shape[-1] > 3000:
102
+ tokens = tokens[..., :3000]
103
+
104
+ with torch.no_grad():
105
+ if melody_is_wav:
106
+ wav_seperate = self.model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)
107
+ else:
108
+ wav_seperate = self.model.generate_audio(tokens)
109
+
110
+ return wav_seperate[0]
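app.py drives this class as MODEL(lyric_norm, description, prompt_audio, genre, "ckpt/prompt.pt", params); a standalone sketch with a genre prompt might look like this (the checkpoint directory and output path are assumptions, not fixed by the repository):

import torchaudio
from levo_inference import LeVoInference

model = LeVoInference("ckpt/songgeneration_base")   # assumed directory holding config.yaml and model.pt
lyric = "[intro-short] ; [verse] line one.line two ; [chorus] hook line ; [outro-short]"
wav = model(lyric, description=None, prompt_audio_path=None,
            genre="Pop", auto_prompt_path="ckpt/prompt.pt")
torchaudio.save("song.flac", wav.cpu().float(), model.cfg.sample_rate)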
tools/gradio/levo_inference_lowmem.py ADDED
@@ -0,0 +1,129 @@
1
+ import os
2
+ import sys
3
+
4
+ import torch
5
+
6
+ import json
7
+ import numpy as np
8
+ from omegaconf import OmegaConf
9
+
10
+ from codeclm.trainer.codec_song_pl import CodecLM_PL
11
+ from codeclm.models import CodecLM
12
+ from codeclm.models import builders
13
+
14
+ from separator import Separator
15
+
16
+
17
+ class LeVoInference(torch.nn.Module):
18
+ def __init__(self, ckpt_path):
19
+ super().__init__()
20
+
21
+ torch.backends.cudnn.enabled = False
22
+ OmegaConf.register_new_resolver("eval", lambda x: eval(x))
23
+ OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
24
+ OmegaConf.register_new_resolver("get_fname", lambda: 'default')
25
+ OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
26
+
27
+ cfg_path = os.path.join(ckpt_path, 'config.yaml')
28
+ self.pt_path = os.path.join(ckpt_path, 'model.pt')
29
+
30
+ self.cfg = OmegaConf.load(cfg_path)
31
+ self.cfg.mode = 'inference'
32
+ self.max_duration = self.cfg.max_dur
33
+
34
+ self.default_params = dict(
35
+ top_p = 0.0,
36
+ record_tokens = True,
37
+ record_window = 50,
38
+ extend_stride = 5,
39
+ duration = self.max_duration,
40
+ )
41
+
42
+
43
+ def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, params = dict()):
44
+ if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
45
+ separator = Separator()
46
+ audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
47
+ audio_tokenizer = audio_tokenizer.eval().cuda()
48
+ seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
49
+ seperate_tokenizer = seperate_tokenizer.eval().cuda()
50
+ pmt_wav, vocal_wav, bgm_wav = separator.run(prompt_audio_path)
51
+ pmt_wav = pmt_wav.cuda()
52
+ vocal_wav = vocal_wav.cuda()
53
+ bgm_wav = bgm_wav.cuda()
54
+ pmt_wav, _ = audio_tokenizer.encode(pmt_wav)
55
+ vocal_wav, bgm_wav = seperate_tokenizer.encode(vocal_wav, bgm_wav)
56
+ melody_is_wav = False
58
+ del audio_tokenizer
59
+ del seperate_tokenizer
60
+ del separator
61
+ elif genre is not None and auto_prompt_path is not None:
62
+ auto_prompt = torch.load(auto_prompt_path)
63
+ merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
64
+ if genre == "Auto":
65
+ prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
66
+ else:
67
+ prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
68
+ pmt_wav = prompt_token[:,[0],:]
69
+ vocal_wav = prompt_token[:,[1],:]
70
+ bgm_wav = prompt_token[:,[2],:]
71
+ melody_is_wav = False
72
+ else:
73
+ pmt_wav = None
74
+ vocal_wav = None
75
+ bgm_wav = None
76
+ melody_is_wav = True
77
+
78
+ model_light = CodecLM_PL(self.cfg, self.pt_path)
79
+ model_light = model_light.eval()
80
+ model_light.audiolm.cfg = self.cfg
81
+ model = CodecLM(name = "tmp",
82
+ lm = model_light.audiolm,
83
+ audiotokenizer = None,
84
+ max_duration = self.max_duration,
85
+ seperate_tokenizer = None,
86
+ )
87
+ del model_light
88
+ model.lm = model.lm.cuda().to(torch.float16)
89
+ params = {**self.default_params, **params}
90
+ model.set_generation_params(**params)
91
+
92
+ generate_inp = {
93
+ 'lyrics': [lyric.replace(" ", " ")],
94
+ 'descriptions': [description],
95
+ 'melody_wavs': pmt_wav,
96
+ 'vocal_wavs': vocal_wav,
97
+ 'bgm_wavs': bgm_wav,
98
+ 'melody_is_wav': melody_is_wav,
99
+ }
100
+
101
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
102
+ tokens = model.generate(**generate_inp, return_tokens=True)
103
+
104
+ del model
105
+ torch.cuda.empty_cache()
106
+
107
+ seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
108
+ seperate_tokenizer = seperate_tokenizer.eval().cuda()
109
+ model = CodecLM(name = "tmp",
110
+ lm = None,
111
+ audiotokenizer = None,
112
+ max_duration = self.max_duration,
113
+ seperate_tokenizer = seperate_tokenizer,
114
+ )
115
+
116
+ if tokens.shape[-1] > 3000:
117
+ tokens = tokens[..., :3000]
118
+
119
+ with torch.no_grad():
120
+ if melody_is_wav:
121
+ wav_seperate = model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)
122
+ else:
123
+ wav_seperate = model.generate_audio(tokens)
124
+
125
+ del seperate_tokenizer
126
+ del model
127
+ torch.cuda.empty_cache()
128
+
129
+ return wav_seperate[0]
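The low-memory variant keeps the same call signature but builds each stage inside forward and frees it afterwards, trading extra load time per call for a lower peak VRAM footprint. Usage with an audio prompt, under the same path assumptions as above:

import torchaudio
from levo_inference_lowmem import LeVoInference

model = LeVoInference("ckpt/songgeneration_base")   # assumed checkpoint directory
wav = model("[verse] line one.line two ; [chorus] hook line",
            prompt_audio_path="input/sample_prompt_audio.wav")
torchaudio.save("song_lowmem.flac", wav.cpu().float(), model.cfg.sample_rate)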
tools/gradio/run.sh ADDED
@@ -0,0 +1,9 @@
1
+ export USER=root
2
+ export PYTHONDONTWRITEBYTECODE=1
3
+ export TRANSFORMERS_CACHE="$(pwd)/third_party/hub"
4
+ export NCCL_HOME=/usr/local/tccl
5
+ export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH
6
+
7
+
8
+ CKPT_PATH=$1
9
+ python3 tools/gradio/app.py $CKPT_PATH
tools/gradio/separator.py ADDED
@@ -0,0 +1,50 @@
1
+ import torchaudio
2
+ import os
3
+ import torch
4
+ from third_party.demucs.models.pretrained import get_model_from_yaml
5
+
6
+
7
+ class Separator(torch.nn.Module):
8
+ def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
9
+ super().__init__()
10
+ if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
11
+ self.device = torch.device(f"cuda:{gpu_id}")
12
+ else:
13
+ self.device = torch.device("cpu")
14
+ self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
15
+
16
+ def init_demucs_model(self, model_path, config_path):
17
+ model = get_model_from_yaml(config_path, model_path)
18
+ model.to(self.device)
19
+ model.eval()
20
+ return model
21
+
22
+ def load_audio(self, f):
23
+ a, fs = torchaudio.load(f)
24
+ if (fs != 48000):
25
+ a = torchaudio.functional.resample(a, fs, 48000)
26
+ if a.shape[-1] >= 48000*10:
27
+ a = a[..., :48000*10]
28
+ else:
29
+ a = torch.cat([a, a], -1)
30
+ return a[:, 0:48000*10]
31
+
32
+ def run(self, audio_path, output_dir='tmp', ext=".flac"):
33
+ os.makedirs(output_dir, exist_ok=True)
34
+ name, _ = os.path.splitext(os.path.split(audio_path)[-1])
35
+ output_paths = []
36
+
37
+ for stem in self.demucs_model.sources:
38
+ output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
39
+ if os.path.exists(output_path):
40
+ output_paths.append(output_path)
41
+ if len(output_paths) == 1:  # only the vocal stem survives earlier runs; otherwise separate all 4 Demucs stems
42
+ vocal_path = output_paths[0]
43
+ else:
44
+ drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
45
+ for path in [drums_path, bass_path, other_path]:
46
+ os.remove(path)
47
+ full_audio = self.load_audio(audio_path)
48
+ vocal_audio = self.load_audio(vocal_path)
49
+ bgm_audio = full_audio - vocal_audio
50
+ return full_audio, vocal_audio, bgm_audio
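Separator.run returns three 48 kHz tensors clipped to roughly ten seconds: the full mix, the Demucs vocal stem, and their difference as the backing track. A standalone sketch, assuming the htdemucs checkpoint sits at the default third_party/demucs/ckpt paths and using a placeholder input file:

import torchaudio
from separator import Separator

sep = Separator()
full, vocal, bgm = sep.run("input/sample_prompt_audio.wav", output_dir="tmp")
print(full.shape, vocal.shape, bgm.shape)   # each [channels, ~480000] (10 s at 48 kHz)
torchaudio.save("tmp/bgm_only.flac", bgm, 48000)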