kemuriririn committed
Commit 887e50c · 1 Parent(s): 590b29f
config.py ADDED
@@ -0,0 +1,218 @@
1
+ import os
2
+ import re
3
+ import sys
4
+
5
+ import torch
6
+
7
+ from tools.i18n.i18n import I18nAuto
8
+
9
+ i18n = I18nAuto(language=os.environ.get("language", "Auto"))
10
+
11
+
12
+ pretrained_sovits_name = {
13
+ "v1": "pretrained_models/s2G488k.pth",
14
+ "v2": "pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
15
+ "v3": "pretrained_models/s2Gv3.pth",  # v3/v4 still need their vocoder checked; leaving it for now
16
+ "v4": "pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
17
+ "v2Pro": "pretrained_models/v2Pro/s2Gv2Pro.pth",
18
+ "v2ProPlus": "pretrained_models/v2Pro/s2Gv2ProPlus.pth",
19
+ }
20
+
21
+ pretrained_gpt_name = {
22
+ "v1": "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
23
+ "v2": "pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
24
+ "v3": "pretrained_models/s1v3.ckpt",
25
+ "v4": "pretrained_models/s1v3.ckpt",
26
+ "v2Pro": "pretrained_models/s1v3.ckpt",
27
+ "v2ProPlus": "pretrained_models/s1v3.ckpt",
28
+ }
29
+ name2sovits_path = {
30
+ # i18n("不训练直接推v1底模!"): "pretrained_models/s2G488k.pth",
31
+ i18n("不训练直接推v2底模!"): "pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
32
+ # i18n("不训练直接推v3底模!"): "pretrained_models/s2Gv3.pth",
33
+ # i18n("不训练直接推v4底模!"): "pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
34
+ i18n("不训练直接推v2Pro底模!"): "pretrained_models/v2Pro/s2Gv2Pro.pth",
35
+ i18n("不训练直接推v2ProPlus底模!"): "pretrained_models/v2Pro/s2Gv2ProPlus.pth",
36
+ }
37
+ name2gpt_path = {
38
+ # i18n("不训练直接推v1底模!"):"pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
39
+ i18n(
40
+ "不训练直接推v2底模!"
41
+ ): "pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
42
+ i18n("不训练直接推v3底模!"): "pretrained_models/s1v3.ckpt",
43
+ }
44
+ SoVITS_weight_root = [
45
+ "SoVITS_weights",
46
+ "SoVITS_weights_v2",
47
+ "SoVITS_weights_v3",
48
+ "SoVITS_weights_v4",
49
+ "SoVITS_weights_v2Pro",
50
+ "SoVITS_weights_v2ProPlus",
51
+ ]
52
+ GPT_weight_root = [
53
+ "GPT_weights",
54
+ "GPT_weights_v2",
55
+ "GPT_weights_v3",
56
+ "GPT_weights_v4",
57
+ "GPT_weights_v2Pro",
58
+ "GPT_weights_v2ProPlus",
59
+ ]
60
+ SoVITS_weight_version2root = {
61
+ "v1": "SoVITS_weights",
62
+ "v2": "SoVITS_weights_v2",
63
+ "v3": "SoVITS_weights_v3",
64
+ "v4": "SoVITS_weights_v4",
65
+ "v2Pro": "SoVITS_weights_v2Pro",
66
+ "v2ProPlus": "SoVITS_weights_v2ProPlus",
67
+ }
68
+ GPT_weight_version2root = {
69
+ "v1": "GPT_weights",
70
+ "v2": "GPT_weights_v2",
71
+ "v3": "GPT_weights_v3",
72
+ "v4": "GPT_weights_v4",
73
+ "v2Pro": "GPT_weights_v2Pro",
74
+ "v2ProPlus": "GPT_weights_v2ProPlus",
75
+ }
76
+
77
+
78
+ def custom_sort_key(s):
79
+ # Use a regular expression to split the string into numeric and non-numeric parts
80
+ parts = re.split(r"(\d+)", s)
81
+ # Convert the numeric parts to integers and keep the rest as strings
82
+ parts = [int(part) if part.isdigit() else part for part in parts]
83
+ return parts
84
+
85
+
86
+ def get_weights_names():
87
+ SoVITS_names = []
88
+ for key in name2sovits_path:
89
+ if os.path.exists(name2sovits_path[key]):
90
+ SoVITS_names.append(key)
91
+ for path in SoVITS_weight_root:
92
+ if not os.path.exists(path):
93
+ continue
94
+ for name in os.listdir(path):
95
+ if name.endswith(".pth"):
96
+ SoVITS_names.append("%s/%s" % (path, name))
97
+ if not SoVITS_names:
98
+ SoVITS_names = [""]
99
+ GPT_names = []
100
+ for key in name2gpt_path:
101
+ if os.path.exists(name2gpt_path[key]):
102
+ GPT_names.append(key)
103
+ for path in GPT_weight_root:
104
+ if not os.path.exists(path):
105
+ continue
106
+ for name in os.listdir(path):
107
+ if name.endswith(".ckpt"):
108
+ GPT_names.append("%s/%s" % (path, name))
109
+ SoVITS_names = sorted(SoVITS_names, key=custom_sort_key)
110
+ GPT_names = sorted(GPT_names, key=custom_sort_key)
111
+ if not GPT_names:
112
+ GPT_names = [""]
113
+ return SoVITS_names, GPT_names
114
+
115
+
116
+ def change_choices():
117
+ SoVITS_names, GPT_names = get_weights_names()
118
+ return {"choices": SoVITS_names, "__type__": "update"}, {
119
+ "choices": GPT_names,
120
+ "__type__": "update",
121
+ }
122
+
123
+
124
+ # Models explicitly selected for inference
125
+ sovits_path = ""
126
+ gpt_path = ""
127
+ is_half_str = os.environ.get("is_half", "True")
128
+ is_half = is_half_str.lower() == "true"
129
+ is_share_str = os.environ.get("is_share", "False")
130
+ is_share = is_share_str.lower() == "true"
131
+
132
+ cnhubert_path = "pretrained_models/chinese-hubert-base"
133
+ bert_path = "pretrained_models/chinese-roberta-wwm-ext-large"
134
+ pretrained_sovits_path = "pretrained_models/s2G488k.pth"
135
+ pretrained_gpt_path = "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
136
+
137
+ exp_root = "logs"
138
+ python_exec = sys.executable or "python"
139
+
140
+ webui_port_main = 9874
141
+ webui_port_uvr5 = 9873
142
+ webui_port_infer_tts = 9872
143
+ webui_port_subfix = 9871
144
+
145
+ api_port = 9880
146
+
147
+
148
+ # Thanks to the contribution of @Karasukaigan and @XXXXRT666
149
+ def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
150
+ cpu = torch.device("cpu")
151
+ cuda = torch.device(f"cuda:{idx}")
152
+ if not torch.cuda.is_available():
153
+ return cpu, torch.float32, 0.0, 0.0
154
+ device_idx = idx
155
+ capability = torch.cuda.get_device_capability(device_idx)
156
+ name = torch.cuda.get_device_name(device_idx)
157
+ mem_bytes = torch.cuda.get_device_properties(device_idx).total_memory
158
+ mem_gb = mem_bytes / (1024**3) + 0.4
159
+ major, minor = capability
160
+ sm_version = major + minor / 10.0
161
+ is_16_series = bool(re.search(r"16\d{2}", name)) and sm_version == 7.5
162
+ if mem_gb < 4 or sm_version < 5.3:
163
+ return cpu, torch.float32, 0.0, 0.0
164
+ if sm_version == 6.1 or is_16_series:
165
+ return cuda, torch.float32, sm_version, mem_gb
166
+ if sm_version > 6.1:
167
+ return cuda, torch.float16, sm_version, mem_gb
168
+ return cpu, torch.float32, 0.0, 0.0
169
+
170
+
171
+ IS_GPU = True
172
+ GPU_INFOS: list[str] = []
173
+ GPU_INDEX: set[int] = set()
174
+ GPU_COUNT = torch.cuda.device_count()
175
+ CPU_INFO: str = "0\tCPU " + i18n("CPU训练,较慢")
176
+ tmp: list[tuple[torch.device, torch.dtype, float, float]] = []
177
+ memset: set[float] = set()
178
+
179
+ for i in range(max(GPU_COUNT, 1)):
180
+ tmp.append(get_device_dtype_sm(i))
181
+
182
+ for j in tmp:
183
+ device = j[0]
184
+ memset.add(j[3])
185
+ if device.type != "cpu":
186
+ GPU_INFOS.append(f"{device.index}\t{torch.cuda.get_device_name(device.index)}")
187
+ GPU_INDEX.add(device.index)
188
+
189
+ if not GPU_INFOS:
190
+ IS_GPU = False
191
+ GPU_INFOS.append(CPU_INFO)
192
+ GPU_INDEX.add(0)
193
+
194
+ infer_device = max(tmp, key=lambda x: (x[2], x[3]))[0]
195
+ is_half = any(dtype == torch.float16 for _, dtype, _, _ in tmp)
196
+
197
+
198
+ class Config:
199
+ def __init__(self):
200
+ self.sovits_path = sovits_path
201
+ self.gpt_path = gpt_path
202
+ self.is_half = is_half
203
+
204
+ self.cnhubert_path = cnhubert_path
205
+ self.bert_path = bert_path
206
+ self.pretrained_sovits_path = pretrained_sovits_path
207
+ self.pretrained_gpt_path = pretrained_gpt_path
208
+
209
+ self.exp_root = exp_root
210
+ self.python_exec = python_exec
211
+ self.infer_device = infer_device
212
+
213
+ self.webui_port_main = webui_port_main
214
+ self.webui_port_uvr5 = webui_port_uvr5
215
+ self.webui_port_infer_tts = webui_port_infer_tts
216
+ self.webui_port_subfix = webui_port_subfix
217
+
218
+ self.api_port = api_port
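
For orientation (not part of the commit): config.py does its device probing at import time, so a launcher only needs to import it and read the results. A minimal sketch using only names defined in the file above:

```python
# Minimal sketch: consume the import-time results of config.py.
import config

cfg = config.Config()
print("inference device:", cfg.infer_device)        # picked by get_device_dtype_sm()
print("half precision:", cfg.is_half)               # True only if some GPU supports fp16
print("TTS WebUI port:", cfg.webui_port_infer_tts)  # 9872 by default

# get_weights_names() lists the pretrained fallbacks plus anything found in the
# SoVITS_weights*/GPT_weights* folders, sorted with custom_sort_key().
sovits_choices, gpt_choices = config.get_weights_names()
print(sovits_choices[:3], gpt_choices[:3])
```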
inference_webui.py CHANGED
@@ -29,7 +29,7 @@ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
 logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
 warnings.simplefilter(action="ignore", category=FutureWarning)
 
-version = model_version = os.environ.get("version", "v2")
+version = model_version = os.environ.get("version", "v2ProPlus")
 
 from config import change_choices, get_weights_names, name2gpt_path, name2sovits_path
 
@@ -88,7 +88,7 @@ cnhubert.cnhubert_base_path = cnhubert_base_path
 
 import random
 
-from GPT_SoVITS.module.models import Generator, SynthesizerTrn, SynthesizerTrnV3
+from module.models import Generator, SynthesizerTrn, SynthesizerTrnV3
 
 
 def set_seed(seed):
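
The change above only moves the default: `version` still comes from the environment first, so older model families remain selectable. A hypothetical override, assuming the variable is set before the module is imported:

```python
# Hypothetical: select the v2 default instead of the new v2ProPlus default.
import os

os.environ["version"] = "v2"   # read by os.environ.get("version", "v2ProPlus")
import inference_webui  # noqa: E402  (must come after the environment tweak)
```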
requirements.txt CHANGED
@@ -33,4 +33,5 @@ torch==2.4
 pydantic<=2.10.6
 torchmetrics<=1.5
 fast_langdetect
-split_lang
+split_lang
+peft
tools/AP_BWE_main/24kto48k/readme.txt ADDED
@@ -0,0 +1,11 @@
1
+ For the inference of the v3 model, if you find that the generated audio sounds somewhat muffled, you can try using this audio super-resolution model.
2
+ 对于v3模型的推理,如果你发现生成的音频比较闷,可以尝试这个音频超分模型。
3
+
4
+ put g_24kto48k.zip and config.json in this folder
5
+ 把g_24kto48k.zip and config.json下到这个文件夹
6
+
7
+ download link 下载链接:
8
+ https://drive.google.com/drive/folders/1IIYTf2zbJWzelu4IftKD6ooHloJ8mnZF?usp=share_link
9
+
10
+ audio sr project page 音频超分项目主页:
11
+ https://github.com/yxlu-0102/AP-BWE
tools/AP_BWE_main/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Ye-Xin Lu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
tools/AP_BWE_main/README.md ADDED
@@ -0,0 +1,91 @@
1
+ # Towards High-Quality and Efficient Speech Bandwidth Extension with Parallel Amplitude and Phase Prediction
2
+ ### Ye-Xin Lu, Yang Ai, Hui-Peng Du, Zhen-Hua Ling
3
+
4
+ **Abstract:**
5
+ Speech bandwidth extension (BWE) refers to widening the frequency bandwidth range of speech signals, enhancing the speech quality so that it sounds brighter and fuller.
6
+ This paper proposes a generative adversarial network (GAN) based BWE model with parallel prediction of Amplitude and Phase spectra, named AP-BWE, which achieves both high-quality and efficient wideband speech waveform generation.
7
+ The proposed AP-BWE generator is entirely based on convolutional neural networks (CNNs).
8
+ It features a dual-stream architecture with mutual interaction, where the amplitude stream and the phase stream communicate with each other and respectively extend the high-frequency components from the input narrowband amplitude and phase spectra.
9
+ To improve the naturalness of the extended speech signals, we employ a multi-period discriminator at the waveform level and design a pair of multi-resolution amplitude and phase discriminators at the spectral level, respectively.
10
+ Experimental results demonstrate that our proposed AP-BWE achieves state-of-the-art performance in terms of speech quality for BWE tasks targeting sampling rates of both 16 kHz and 48 kHz.
11
+ In terms of generation efficiency, due to the all-convolutional architecture and all-frame-level operations, the proposed AP-BWE can generate 48 kHz waveform samples 292.3 times faster than real-time on a single RTX 4090 GPU and 18.1 times faster than real-time on a single CPU.
12
+ Notably, to our knowledge, AP-BWE is the first to achieve the direct extension of the high-frequency phase spectrum, which is beneficial for improving the effectiveness of existing BWE methods.
13
+
14
+ **We provide our implementation as open source in this repository. Audio samples can be found at the [demo website](http://yxlu-0102.github.io/AP-BWE).**
15
+
16
+
17
+ ## Pre-requisites
18
+ 0. Python >= 3.9.
19
+ 0. Clone this repository.
20
+ 0. Install python requirements. Please refer to [requirements.txt](requirements.txt).
21
+ 0. Download datasets
22
+ 1. Download and extract the [VCTK-0.92 dataset](https://datashare.ed.ac.uk/handle/10283/3443), and move its `wav48` directory into [VCTK-Corpus-0.92](VCTK-Corpus-0.92) and rename it as `wav48_origin`.
23
+ 1. Trim the silence of the dataset, and the trimmed files will be saved to `wav48_silence_trimmed`.
24
+ ```
25
+ cd VCTK-Corpus-0.92
26
+ python flac2wav.py
27
+ ```
28
+ 1. Move all the trimmed training files from `wav48_silence_trimmed` to [wav48/train](wav48/train) following the indexes in [training.txt](VCTK-Corpus-0.92/training.txt), and move all the untrimmed test files from `wav48_origin` to [wav48/test](wav48/test) following the indexes in [test.txt](VCTK-Corpus-0.92/test.txt).
29
+
30
+ ## Training
31
+ ```
32
+ cd train
33
+ CUDA_VISIBLE_DEVICES=0 python train_16k.py --config [config file path]
34
+ CUDA_VISIBLE_DEVICES=0 python train_48k.py --config [config file path]
35
+ ```
36
+ Checkpoints and copies of the configuration file are saved in the `cp_model` directory by default.<br>
37
+ You can change the path by using the `--checkpoint_path` option.
38
+ Here is an example:
39
+ ```
40
+ CUDA_VISIBLE_DEVICES=0 python train_16k.py --config ../configs/config_2kto16k.json --checkpoint_path ../checkpoints/AP-BWE_2kto16k
41
+ ```
42
+
43
+ ## Inference
44
+ ```
45
+ cd inference
46
+ python inference_16k.py --checkpoint_file [generator checkpoint file path]
47
+ python inference_48k.py --checkpoint_file [generator checkpoint file path]
48
+ ```
49
+ You can download the [pretrained weights](https://drive.google.com/drive/folders/1IIYTf2zbJWzelu4IftKD6ooHloJ8mnZF?usp=share_link) we provide and move all the files to the `checkpoints` directory.
50
+ <br>
51
+ Generated wav files are saved in `generated_files` by default.
52
+ You can change the path by adding `--output_dir` option.
53
+ Here is an example:
54
+ ```
55
+ python inference_16k.py --checkpoint_file ../checkpoints/2kto16k/g_2kto16k --output_dir ../generated_files/2kto16k
56
+ ```
57
+
58
+ ## Model Structure
59
+ ![model](Figures/model.png)
60
+
61
+ ## Comparison with other speech BWE methods
62
+ ### 2k/4k/8kHz to 16kHz
63
+ <p align="center">
64
+ <img src="Figures/table_16k.png" alt="comparison" width="90%"/>
65
+ </p>
66
+
67
+ ### 8k/12k/16k/24kHz to 48kHz
68
+ <p align="center">
69
+ <img src="Figures/table_48k.png" alt="comparison" width="100%"/>
70
+ </p>
71
+
72
+ ## Acknowledgements
73
+ We referred to [HiFi-GAN](https://github.com/jik876/hifi-gan) and [NSPP](https://github.com/YangAi520/NSPP) to implement this.
74
+
75
+ ## Citation
76
+ ```
77
+ @article{lu2024towards,
78
+ title={Towards high-quality and efficient speech bandwidth extension with parallel amplitude and phase prediction},
79
+ author={Lu, Ye-Xin and Ai, Yang and Du, Hui-Peng and Ling, Zhen-Hua},
80
+ journal={arXiv preprint arXiv:2401.06387},
81
+ year={2024}
82
+ }
83
+
84
+ @inproceedings{lu2024multi,
85
+ title={Multi-Stage Speech Bandwidth Extension with Flexible Sampling Rate Control},
86
+ author={Lu, Ye-Xin and Ai, Yang and Sheng, Zheng-Yan and Ling, Zhen-Hua},
87
+ booktitle={Proc. Interspeech},
88
+ pages={2270--2274},
89
+ year={2024}
90
+ }
91
+ ```
tools/AP_BWE_main/datasets1/__init__.py ADDED
@@ -0,0 +1 @@
1
+
tools/AP_BWE_main/datasets1/dataset.py ADDED
@@ -0,0 +1,108 @@
1
+ import os
2
+ import random
3
+ import torch
4
+ import torchaudio
5
+ import torch.utils.data
6
+ import torchaudio.functional as aF
7
+
8
+
9
+ def amp_pha_stft(audio, n_fft, hop_size, win_size, center=True):
10
+ hann_window = torch.hann_window(win_size).to(audio.device)
11
+ stft_spec = torch.stft(
12
+ audio,
13
+ n_fft,
14
+ hop_length=hop_size,
15
+ win_length=win_size,
16
+ window=hann_window,
17
+ center=center,
18
+ pad_mode="reflect",
19
+ normalized=False,
20
+ return_complex=True,
21
+ )
22
+ log_amp = torch.log(torch.abs(stft_spec) + 1e-4)
23
+ pha = torch.angle(stft_spec)
24
+
25
+ com = torch.stack((torch.exp(log_amp) * torch.cos(pha), torch.exp(log_amp) * torch.sin(pha)), dim=-1)
26
+
27
+ return log_amp, pha, com
28
+
29
+
30
+ def amp_pha_istft(log_amp, pha, n_fft, hop_size, win_size, center=True):
31
+ amp = torch.exp(log_amp)
32
+ com = torch.complex(amp * torch.cos(pha), amp * torch.sin(pha))
33
+ hann_window = torch.hann_window(win_size).to(com.device)
34
+ audio = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center)
35
+
36
+ return audio
37
+
38
+
39
+ def get_dataset_filelist(a):
40
+ with open(a.input_training_file, "r", encoding="utf-8") as fi:
41
+ training_indexes = [x.split("|")[0] for x in fi.read().split("\n") if len(x) > 0]
42
+
43
+ with open(a.input_validation_file, "r", encoding="utf-8") as fi:
44
+ validation_indexes = [x.split("|")[0] for x in fi.read().split("\n") if len(x) > 0]
45
+
46
+ return training_indexes, validation_indexes
47
+
48
+
49
+ class Dataset(torch.utils.data.Dataset):
50
+ def __init__(
51
+ self,
52
+ training_indexes,
53
+ wavs_dir,
54
+ segment_size,
55
+ hr_sampling_rate,
56
+ lr_sampling_rate,
57
+ split=True,
58
+ shuffle=True,
59
+ n_cache_reuse=1,
60
+ device=None,
61
+ ):
62
+ self.audio_indexes = training_indexes
63
+ random.seed(1234)
64
+ if shuffle:
65
+ random.shuffle(self.audio_indexes)
66
+ self.wavs_dir = wavs_dir
67
+ self.segment_size = segment_size
68
+ self.hr_sampling_rate = hr_sampling_rate
69
+ self.lr_sampling_rate = lr_sampling_rate
70
+ self.split = split
71
+ self.cached_wav = None
72
+ self.n_cache_reuse = n_cache_reuse
73
+ self._cache_ref_count = 0
74
+ self.device = device
75
+
76
+ def __getitem__(self, index):
77
+ filename = self.audio_indexes[index]
78
+ if self._cache_ref_count == 0:
79
+ audio, orig_sampling_rate = torchaudio.load(os.path.join(self.wavs_dir, filename + ".wav"))
80
+ self.cached_wav = audio
+ self.cached_sampling_rate = orig_sampling_rate  # cache the sample rate alongside the audio
81
+ self._cache_ref_count = self.n_cache_reuse
82
+ else:
83
+ audio = self.cached_wav
+ orig_sampling_rate = self.cached_sampling_rate  # previously undefined on cache hits
84
+ self._cache_ref_count -= 1
85
+
86
+ if orig_sampling_rate == self.hr_sampling_rate:
87
+ audio_hr = audio
88
+ else:
89
+ audio_hr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=self.hr_sampling_rate)
90
+
91
+ audio_lr = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=self.lr_sampling_rate)
92
+ audio_lr = aF.resample(audio_lr, orig_freq=self.lr_sampling_rate, new_freq=self.hr_sampling_rate)
93
+ audio_lr = audio_lr[:, : audio_hr.size(1)]
94
+
95
+ if self.split:
96
+ if audio_hr.size(1) >= self.segment_size:
97
+ max_audio_start = audio_hr.size(1) - self.segment_size
98
+ audio_start = random.randint(0, max_audio_start)
99
+ audio_hr = audio_hr[:, audio_start : audio_start + self.segment_size]
100
+ audio_lr = audio_lr[:, audio_start : audio_start + self.segment_size]
101
+ else:
102
+ audio_hr = torch.nn.functional.pad(audio_hr, (0, self.segment_size - audio_hr.size(1)), "constant")
103
+ audio_lr = torch.nn.functional.pad(audio_lr, (0, self.segment_size - audio_lr.size(1)), "constant")
104
+
105
+ return (audio_hr.squeeze(), audio_lr.squeeze())
106
+
107
+ def __len__(self):
108
+ return len(self.audio_indexes)
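
As a quick sanity check on the two helpers at the top of this file (illustrative only: the n_fft/hop/win values below are placeholders, the real ones come from the AP-BWE config.json, and the import assumes tools/AP_BWE_main is on sys.path, as tools/audio_sr.py arranges):

```python
# Round trip: amp_pha_stft followed by amp_pha_istft should reproduce the input
# waveform up to small numerical error (the log is taken on |S| + 1e-4).
import torch
from datasets1.dataset import amp_pha_istft, amp_pha_stft

audio = torch.randn(1, 48000)  # (batch, samples); placeholder 1 s of noise at 48 kHz
log_amp, pha, com = amp_pha_stft(audio, n_fft=1024, hop_size=80, win_size=320)
recon = amp_pha_istft(log_amp, pha, n_fft=1024, hop_size=80, win_size=320)
print(log_amp.shape, pha.shape, com.shape)  # (1, 513, frames) twice, plus (1, 513, frames, 2)
print((audio[..., : recon.shape[-1]] - recon).abs().max())
```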
tools/AP_BWE_main/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+
tools/AP_BWE_main/models/model.py ADDED
@@ -0,0 +1,464 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn.utils import weight_norm, spectral_norm
5
+
6
+
7
+ # from utils import init_weights, get_padding
8
+ def get_padding(kernel_size, dilation=1):
9
+ return int((kernel_size * dilation - dilation) / 2)
10
+
11
+
12
+ def init_weights(m, mean=0.0, std=0.01):
13
+ classname = m.__class__.__name__
14
+ if classname.find("Conv") != -1:
15
+ m.weight.data.normal_(mean, std)
16
+
17
+
18
+ import numpy as np
19
+ from typing import Tuple, List
20
+
21
+ LRELU_SLOPE = 0.1
22
+
23
+
24
+ class ConvNeXtBlock(nn.Module):
25
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
26
+
27
+ Args:
28
+ dim (int): Number of input channels.
29
+ intermediate_dim (int): Dimensionality of the intermediate layer.
30
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
31
+ Defaults to None.
32
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
33
+ None means non-conditional LayerNorm. Defaults to None.
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ dim: int,
39
+ layer_scale_init_value=None,
40
+ adanorm_num_embeddings=None,
41
+ ):
42
+ super().__init__()
43
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
44
+ self.adanorm = adanorm_num_embeddings is not None
45
+
46
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
47
+ self.pwconv1 = nn.Linear(dim, dim * 3) # pointwise/1x1 convs, implemented with linear layers
48
+ self.act = nn.GELU()
49
+ self.pwconv2 = nn.Linear(dim * 3, dim)
50
+ self.gamma = (
51
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
52
+ if layer_scale_init_value is not None and layer_scale_init_value > 0
53
+ else None
54
+ )
55
+
56
+ def forward(self, x, cond_embedding_id=None):
57
+ residual = x
58
+ x = self.dwconv(x)
59
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
60
+ if self.adanorm:
61
+ assert cond_embedding_id is not None
62
+ x = self.norm(x, cond_embedding_id)
63
+ else:
64
+ x = self.norm(x)
65
+ x = self.pwconv1(x)
66
+ x = self.act(x)
67
+ x = self.pwconv2(x)
68
+ if self.gamma is not None:
69
+ x = self.gamma * x
70
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
71
+
72
+ x = residual + x
73
+ return x
74
+
75
+
76
+ class APNet_BWE_Model(torch.nn.Module):
77
+ def __init__(self, h):
78
+ super(APNet_BWE_Model, self).__init__()
79
+ self.h = h
80
+ self.adanorm_num_embeddings = None
81
+ layer_scale_init_value = 1 / h.ConvNeXt_layers
82
+
83
+ self.conv_pre_mag = nn.Conv1d(h.n_fft // 2 + 1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
84
+ self.norm_pre_mag = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
85
+ self.conv_pre_pha = nn.Conv1d(h.n_fft // 2 + 1, h.ConvNeXt_channels, 7, 1, padding=get_padding(7, 1))
86
+ self.norm_pre_pha = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
87
+
88
+ self.convnext_mag = nn.ModuleList(
89
+ [
90
+ ConvNeXtBlock(
91
+ dim=h.ConvNeXt_channels,
92
+ layer_scale_init_value=layer_scale_init_value,
93
+ adanorm_num_embeddings=self.adanorm_num_embeddings,
94
+ )
95
+ for _ in range(h.ConvNeXt_layers)
96
+ ]
97
+ )
98
+
99
+ self.convnext_pha = nn.ModuleList(
100
+ [
101
+ ConvNeXtBlock(
102
+ dim=h.ConvNeXt_channels,
103
+ layer_scale_init_value=layer_scale_init_value,
104
+ adanorm_num_embeddings=self.adanorm_num_embeddings,
105
+ )
106
+ for _ in range(h.ConvNeXt_layers)
107
+ ]
108
+ )
109
+
110
+ self.norm_post_mag = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
111
+ self.norm_post_pha = nn.LayerNorm(h.ConvNeXt_channels, eps=1e-6)
112
+ self.apply(self._init_weights)
113
+ self.linear_post_mag = nn.Linear(h.ConvNeXt_channels, h.n_fft // 2 + 1)
114
+ self.linear_post_pha_r = nn.Linear(h.ConvNeXt_channels, h.n_fft // 2 + 1)
115
+ self.linear_post_pha_i = nn.Linear(h.ConvNeXt_channels, h.n_fft // 2 + 1)
116
+
117
+ def _init_weights(self, m):
118
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
119
+ nn.init.trunc_normal_(m.weight, std=0.02)
120
+ nn.init.constant_(m.bias, 0)
121
+
122
+ def forward(self, mag_nb, pha_nb):
123
+ x_mag = self.conv_pre_mag(mag_nb)
124
+ x_pha = self.conv_pre_pha(pha_nb)
125
+ x_mag = self.norm_pre_mag(x_mag.transpose(1, 2)).transpose(1, 2)
126
+ x_pha = self.norm_pre_pha(x_pha.transpose(1, 2)).transpose(1, 2)
127
+
128
+ for conv_block_mag, conv_block_pha in zip(self.convnext_mag, self.convnext_pha):
129
+ x_mag = x_mag + x_pha
130
+ x_pha = x_pha + x_mag
131
+ x_mag = conv_block_mag(x_mag, cond_embedding_id=None)
132
+ x_pha = conv_block_pha(x_pha, cond_embedding_id=None)
133
+
134
+ x_mag = self.norm_post_mag(x_mag.transpose(1, 2))
135
+ mag_wb = mag_nb + self.linear_post_mag(x_mag).transpose(1, 2)
136
+
137
+ x_pha = self.norm_post_pha(x_pha.transpose(1, 2))
138
+ x_pha_r = self.linear_post_pha_r(x_pha)
139
+ x_pha_i = self.linear_post_pha_i(x_pha)
140
+ pha_wb = torch.atan2(x_pha_i, x_pha_r).transpose(1, 2)
141
+
142
+ com_wb = torch.stack((torch.exp(mag_wb) * torch.cos(pha_wb), torch.exp(mag_wb) * torch.sin(pha_wb)), dim=-1)
143
+
144
+ return mag_wb, pha_wb, com_wb
145
+
146
+
147
+ class DiscriminatorP(torch.nn.Module):
148
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
149
+ super(DiscriminatorP, self).__init__()
150
+ self.period = period
151
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
152
+ self.convs = nn.ModuleList(
153
+ [
154
+ norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
155
+ norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
156
+ norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
157
+ norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
158
+ norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
159
+ ]
160
+ )
161
+ self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
162
+
163
+ def forward(self, x):
164
+ fmap = []
165
+
166
+ # 1d to 2d
167
+ b, c, t = x.shape
168
+ if t % self.period != 0: # pad first
169
+ n_pad = self.period - (t % self.period)
170
+ x = F.pad(x, (0, n_pad), "reflect")
171
+ t = t + n_pad
172
+ x = x.view(b, c, t // self.period, self.period)
173
+
174
+ for i, l in enumerate(self.convs):
175
+ x = l(x)
176
+ x = F.leaky_relu(x, LRELU_SLOPE)
177
+ if i > 0:
178
+ fmap.append(x)
179
+ x = self.conv_post(x)
180
+ fmap.append(x)
181
+ x = torch.flatten(x, 1, -1)
182
+
183
+ return x, fmap
184
+
185
+
186
+ class MultiPeriodDiscriminator(torch.nn.Module):
187
+ def __init__(self):
188
+ super(MultiPeriodDiscriminator, self).__init__()
189
+ self.discriminators = nn.ModuleList(
190
+ [
191
+ DiscriminatorP(2),
192
+ DiscriminatorP(3),
193
+ DiscriminatorP(5),
194
+ DiscriminatorP(7),
195
+ DiscriminatorP(11),
196
+ ]
197
+ )
198
+
199
+ def forward(self, y, y_hat):
200
+ y_d_rs = []
201
+ y_d_gs = []
202
+ fmap_rs = []
203
+ fmap_gs = []
204
+ for i, d in enumerate(self.discriminators):
205
+ y_d_r, fmap_r = d(y)
206
+ y_d_g, fmap_g = d(y_hat)
207
+ y_d_rs.append(y_d_r)
208
+ fmap_rs.append(fmap_r)
209
+ y_d_gs.append(y_d_g)
210
+ fmap_gs.append(fmap_g)
211
+
212
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
213
+
214
+
215
+ class MultiResolutionAmplitudeDiscriminator(nn.Module):
216
+ def __init__(
217
+ self,
218
+ resolutions: Tuple[Tuple[int, int, int]] = ((512, 128, 512), (1024, 256, 1024), (2048, 512, 2048)),
219
+ num_embeddings: int = None,
220
+ ):
221
+ super().__init__()
222
+ self.discriminators = nn.ModuleList(
223
+ [DiscriminatorAR(resolution=r, num_embeddings=num_embeddings) for r in resolutions]
224
+ )
225
+
226
+ def forward(
227
+ self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
228
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
229
+ y_d_rs = []
230
+ y_d_gs = []
231
+ fmap_rs = []
232
+ fmap_gs = []
233
+
234
+ for d in self.discriminators:
235
+ y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
236
+ y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
237
+ y_d_rs.append(y_d_r)
238
+ fmap_rs.append(fmap_r)
239
+ y_d_gs.append(y_d_g)
240
+ fmap_gs.append(fmap_g)
241
+
242
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
243
+
244
+
245
+ class DiscriminatorAR(nn.Module):
246
+ def __init__(
247
+ self,
248
+ resolution: Tuple[int, int, int],
249
+ channels: int = 64,
250
+ in_channels: int = 1,
251
+ num_embeddings: int = None,
252
+ ):
253
+ super().__init__()
254
+ self.resolution = resolution
255
+ self.in_channels = in_channels
256
+ self.convs = nn.ModuleList(
257
+ [
258
+ weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))),
259
+ weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))),
260
+ weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))),
261
+ weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)),
262
+ weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)),
263
+ ]
264
+ )
265
+ if num_embeddings is not None:
266
+ self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
267
+ torch.nn.init.zeros_(self.emb.weight)
268
+ self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
269
+
270
+ def forward(
271
+ self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
272
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
273
+ fmap = []
274
+ x = x.squeeze(1)
275
+
276
+ x = self.spectrogram(x)
277
+ x = x.unsqueeze(1)
278
+ for l in self.convs:
279
+ x = l(x)
280
+ x = F.leaky_relu(x, LRELU_SLOPE)
281
+ fmap.append(x)
282
+ if cond_embedding_id is not None:
283
+ emb = self.emb(cond_embedding_id)
284
+ h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
285
+ else:
286
+ h = 0
287
+ x = self.conv_post(x)
288
+ fmap.append(x)
289
+ x += h
290
+ x = torch.flatten(x, 1, -1)
291
+
292
+ return x, fmap
293
+
294
+ def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
295
+ n_fft, hop_length, win_length = self.resolution
296
+ amplitude_spectrogram = torch.stft(
297
+ x,
298
+ n_fft=n_fft,
299
+ hop_length=hop_length,
300
+ win_length=win_length,
301
+ window=None, # interestingly rectangular window kind of works here
302
+ center=True,
303
+ return_complex=True,
304
+ ).abs()
305
+
306
+ return amplitude_spectrogram
307
+
308
+
309
+ class MultiResolutionPhaseDiscriminator(nn.Module):
310
+ def __init__(
311
+ self,
312
+ resolutions: Tuple[Tuple[int, int, int]] = ((512, 128, 512), (1024, 256, 1024), (2048, 512, 2048)),
313
+ num_embeddings: int = None,
314
+ ):
315
+ super().__init__()
316
+ self.discriminators = nn.ModuleList(
317
+ [DiscriminatorPR(resolution=r, num_embeddings=num_embeddings) for r in resolutions]
318
+ )
319
+
320
+ def forward(
321
+ self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
322
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
323
+ y_d_rs = []
324
+ y_d_gs = []
325
+ fmap_rs = []
326
+ fmap_gs = []
327
+
328
+ for d in self.discriminators:
329
+ y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
330
+ y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
331
+ y_d_rs.append(y_d_r)
332
+ fmap_rs.append(fmap_r)
333
+ y_d_gs.append(y_d_g)
334
+ fmap_gs.append(fmap_g)
335
+
336
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
337
+
338
+
339
+ class DiscriminatorPR(nn.Module):
340
+ def __init__(
341
+ self,
342
+ resolution: Tuple[int, int, int],
343
+ channels: int = 64,
344
+ in_channels: int = 1,
345
+ num_embeddings: int = None,
346
+ ):
347
+ super().__init__()
348
+ self.resolution = resolution
349
+ self.in_channels = in_channels
350
+ self.convs = nn.ModuleList(
351
+ [
352
+ weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))),
353
+ weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))),
354
+ weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))),
355
+ weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)),
356
+ weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)),
357
+ ]
358
+ )
359
+ if num_embeddings is not None:
360
+ self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
361
+ torch.nn.init.zeros_(self.emb.weight)
362
+ self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
363
+
364
+ def forward(
365
+ self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
366
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
367
+ fmap = []
368
+ x = x.squeeze(1)
369
+
370
+ x = self.spectrogram(x)
371
+ x = x.unsqueeze(1)
372
+ for l in self.convs:
373
+ x = l(x)
374
+ x = F.leaky_relu(x, LRELU_SLOPE)
375
+ fmap.append(x)
376
+ if cond_embedding_id is not None:
377
+ emb = self.emb(cond_embedding_id)
378
+ h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
379
+ else:
380
+ h = 0
381
+ x = self.conv_post(x)
382
+ fmap.append(x)
383
+ x += h
384
+ x = torch.flatten(x, 1, -1)
385
+
386
+ return x, fmap
387
+
388
+ def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
389
+ n_fft, hop_length, win_length = self.resolution
390
+ phase_spectrogram = torch.stft(
391
+ x,
392
+ n_fft=n_fft,
393
+ hop_length=hop_length,
394
+ win_length=win_length,
395
+ window=None, # interestingly rectangular window kind of works here
396
+ center=True,
397
+ return_complex=True,
398
+ ).angle()
399
+
400
+ return phase_spectrogram
401
+
402
+
403
+ def feature_loss(fmap_r, fmap_g):
404
+ loss = 0
405
+ for dr, dg in zip(fmap_r, fmap_g):
406
+ for rl, gl in zip(dr, dg):
407
+ loss += torch.mean(torch.abs(rl - gl))
408
+
409
+ return loss
410
+
411
+
412
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
413
+ loss = 0
414
+ r_losses = []
415
+ g_losses = []
416
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
417
+ r_loss = torch.mean(torch.clamp(1 - dr, min=0))
418
+ g_loss = torch.mean(torch.clamp(1 + dg, min=0))
419
+ loss += r_loss + g_loss
420
+ r_losses.append(r_loss.item())
421
+ g_losses.append(g_loss.item())
422
+
423
+ return loss, r_losses, g_losses
424
+
425
+
426
+ def generator_loss(disc_outputs):
427
+ loss = 0
428
+ gen_losses = []
429
+ for dg in disc_outputs:
430
+ l = torch.mean(torch.clamp(1 - dg, min=0))
431
+ gen_losses.append(l)
432
+ loss += l
433
+
434
+ return loss, gen_losses
435
+
436
+
437
+ def phase_losses(phase_r, phase_g):
438
+ ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
439
+ gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
440
+ iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))
441
+
442
+ return ip_loss, gd_loss, iaf_loss
443
+
444
+
445
+ def anti_wrapping_function(x):
446
+ return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
447
+
448
+
449
+ def stft_mag(audio, n_fft=2048, hop_length=512):
450
+ hann_window = torch.hann_window(n_fft).to(audio.device)
451
+ stft_spec = torch.stft(audio, n_fft, hop_length, window=hann_window, return_complex=True)
452
+ stft_mag = torch.abs(stft_spec)
453
+ return stft_mag
454
+
455
+
456
+ def cal_snr(pred, target):
457
+ snr = (20 * torch.log10(torch.norm(target, dim=-1) / torch.norm(pred - target, dim=-1).clamp(min=1e-8))).mean()
458
+ return snr
459
+
460
+
461
+ def cal_lsd(pred, target):
462
+ sp = torch.log10(stft_mag(pred).square().clamp(1e-8))
463
+ st = torch.log10(stft_mag(target).square().clamp(1e-8))
464
+ return (sp - st).square().mean(dim=1).sqrt().mean()
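
For readers skimming the generator: a shape-level sketch of APNet_BWE_Model (the hyperparameter values below are invented for illustration; the real ones come from the checkpoint's config.json, and the import assumes tools/AP_BWE_main is on sys.path):

```python
# APNet_BWE_Model maps narrow-band (log-amplitude, phase) spectrograms to
# wide-band ones of the same (batch, n_fft//2 + 1, frames) shape, plus a
# complex spectrum stacked as a trailing real/imag dimension.
from types import SimpleNamespace

import torch
from models.model import APNet_BWE_Model

h = SimpleNamespace(n_fft=1024, ConvNeXt_channels=512, ConvNeXt_layers=8)  # placeholder values
model = APNet_BWE_Model(h).eval()

mag_nb = torch.randn(1, h.n_fft // 2 + 1, 100)  # (batch, freq bins, frames)
pha_nb = torch.randn(1, h.n_fft // 2 + 1, 100)
with torch.no_grad():
    mag_wb, pha_wb, com_wb = model(mag_nb, pha_nb)
print(mag_wb.shape, pha_wb.shape, com_wb.shape)  # (1, 513, 100), (1, 513, 100), (1, 513, 100, 2)
```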
tools/assets.py ADDED
@@ -0,0 +1,70 @@
1
+ js = """
2
+ function deleteTheme() {
3
+
4
+ const params = new URLSearchParams(window.location.search);
5
+ if (params.has('__theme')) {
6
+ params.delete('__theme');
7
+ const newUrl = `${window.location.pathname}?${params.toString()}`;
8
+ window.location.replace(newUrl);
9
+ }
10
+
11
+ }
12
+ """
13
+
14
+ css = """
15
+ /* CSSStyleRule */
16
+ .markdown {
17
+ padding: 6px 10px;
18
+ }
19
+
20
+ @media (prefers-color-scheme: light) {
21
+ .markdown {
22
+ background-color: lightblue;
23
+ color: #000;
24
+ }
25
+ }
26
+
27
+ @media (prefers-color-scheme: dark) {
28
+ .markdown {
29
+ background-color: #4b4b4b;
30
+ color: rgb(244, 244, 245);
31
+ }
32
+ }
33
+
34
+ ::selection {
35
+ background: #ffc078 !important;
36
+ }
37
+
38
+ footer {
39
+ height: 50px !important; /* footer height */
40
+ background-color: transparent !important; /* transparent background */
41
+ display: flex;
42
+ justify-content: center; /* center horizontally */
43
+ align-items: center; /* center vertically */
44
+ }
45
+
46
+ footer * {
47
+ display: none !important; /* hide all footer children */
48
+ }
49
+
50
+ """
51
+
52
+ top_html = """
53
+ <div align="center">
54
+ <div style="margin-bottom: 5px; font-size: 15px;">{}</div>
55
+ <div style="display: flex; gap: 80px; justify-content: center;">
56
+ <a href="https://github.com/RVC-Boss/GPT-SoVITS" target="_blank">
57
+ <img src="https://img.shields.io/badge/GitHub-GPT--SoVITS-blue.svg?style=for-the-badge&logo=github" style="width: auto; height: 30px;">
58
+ </a>
59
+ <a href="https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e" target="_blank">
60
+ <img src="https://img.shields.io/badge/简体中文-阅读文档-blue?style=for-the-badge&logo=googledocs&logoColor=white" style="width: auto; height: 30px;">
61
+ </a>
62
+ <a href="https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e" target="_blank">
63
+ <img src="https://img.shields.io/badge/English-READ%20DOCS-blue?style=for-the-badge&logo=googledocs&logoColor=white" style="width: auto; height: 30px;">
64
+ </a>
65
+ <a href="https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE" target="_blank">
66
+ <img src="https://img.shields.io/badge/LICENSE-MIT-green.svg?style=for-the-badge&logo=opensourceinitiative" style="width: auto; height: 30px;">
67
+ </a>
68
+ </div>
69
+ </div>
70
+ """
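
These are plain strings meant to be handed to Gradio; a sketch of how they could be wired up (keyword support such as `js=` varies across Gradio versions, so treat this as an assumption rather than the app's actual launch code):

```python
# Illustrative wiring of the css/js/top_html strings into a Gradio Blocks app.
import gradio as gr

from tools.assets import css, js, top_html

with gr.Blocks(css=css, js=js, title="GPT-SoVITS") as app:
    gr.HTML(top_html.format("GPT-SoVITS WebUI"))  # the {} slot takes a localized banner line
    # ... the rest of the UI would be declared here ...

# app.launch(server_port=9872)
```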
tools/audio_sr.py ADDED
@@ -0,0 +1,50 @@
1
+ from __future__ import absolute_import, division, print_function, unicode_literals
2
+ import sys
3
+ import os
4
+
5
+ AP_BWE_main_dir_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "AP_BWE_main")
6
+ sys.path.append(AP_BWE_main_dir_path)
7
+ import json
8
+ import torch
9
+ import torchaudio.functional as aF
10
+ # from attrdict import AttrDict  # attrdict breaks on Python 3.10+, so it is not used
11
+
12
+ from datasets1.dataset import amp_pha_stft, amp_pha_istft
13
+ from models.model import APNet_BWE_Model
14
+
15
+
16
+ class AP_BWE:
17
+ def __init__(self, device, DictToAttrRecursive, checkpoint_file=None):
18
+ if checkpoint_file is None:
19
+ checkpoint_file = "%s/24kto48k/g_24kto48k.zip" % (AP_BWE_main_dir_path)
20
+ if not os.path.exists(checkpoint_file):
21
+ raise FileNotFoundError(checkpoint_file)
22
+ config_file = os.path.join(os.path.split(checkpoint_file)[0], "config.json")
23
+ with open(config_file) as f:
24
+ data = f.read()
25
+ json_config = json.loads(data)
26
+ # h = AttrDict(json_config)
27
+ h = DictToAttrRecursive(json_config)
28
+ model = APNet_BWE_Model(h).to(device)
29
+ state_dict = torch.load(checkpoint_file, map_location="cpu", weights_only=False)
30
+ model.load_state_dict(state_dict["generator"])
31
+ model.eval()
32
+ self.device = device
33
+ self.model = model
34
+ self.h = h
35
+
36
+ def to(self, *arg, **kwargs):
37
+ self.model.to(*arg, **kwargs)
38
+ self.device = self.model.conv_pre_mag.weight.device
39
+ return self
40
+
41
+ def __call__(self, audio, orig_sampling_rate):
42
+ with torch.no_grad():
43
+ # audio, orig_sampling_rate = torchaudio.load(inp_path)
44
+ # audio = audio.to(self.device)
45
+ audio = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=self.h.hr_sampling_rate)
46
+ amp_nb, pha_nb, com_nb = amp_pha_stft(audio, self.h.n_fft, self.h.hop_size, self.h.win_size)
47
+ amp_wb_g, pha_wb_g, com_wb_g = self.model(amp_nb, pha_nb)
48
+ audio_hr_g = amp_pha_istft(amp_wb_g, pha_wb_g, self.h.n_fft, self.h.hop_size, self.h.win_size)
49
+ # sf.write(opt_path, audio_hr_g.squeeze().cpu().numpy(), self.h.hr_sampling_rate, 'PCM_16')
50
+ return audio_hr_g.squeeze().cpu().numpy(), self.h.hr_sampling_rate
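
A hypothetical caller for the wrapper above. `DictToAttrRecursive` is injected by the caller (the GPT-SoVITS inference code provides one); a tiny stand-in is shown here so the sketch is self-contained, and the input path is a placeholder:

```python
# Upsample a 24 kHz synthesis result to 48 kHz with the bundled AP-BWE checkpoint.
import torch
import torchaudio

from tools.audio_sr import AP_BWE


class DictToAttrRecursive(dict):
    """Minimal stand-in: expose (nested) dict keys as attributes."""
    def __init__(self, d):
        super().__init__(d)
        for k, v in d.items():
            setattr(self, k, DictToAttrRecursive(v) if isinstance(v, dict) else v)


device = "cuda" if torch.cuda.is_available() else "cpu"
sr_model = AP_BWE(device, DictToAttrRecursive)      # needs 24kto48k/g_24kto48k.zip + config.json
audio, sr = torchaudio.load("tts_output_24k.wav")   # placeholder path to a 24 kHz wav
audio_48k, sr_48k = sr_model(audio.to(device), sr)  # numpy waveform at 48000 Hz
```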
weight.json ADDED
@@ -0,0 +1 @@
1
+ {"GPT": {}, "SoVITS": {}}