lj1995 committed
Commit 0a5b75c · 1 Parent(s): 912e20f

Upload folder using huggingface_hub

README.md CHANGED
@@ -1,10 +1,10 @@
  ---
- title: GPT SoVITS V2
+ title: GPT SoVITS V2 Pro Plus
  emoji: 🤗
  colorFrom: indigo
  colorTo: red
  sdk: gradio
- sdk_version: 4.24.0
+ sdk_version: 4.44.1
  app_file: inference_webui.py
  pinned: false
  license: mit
eres2net/ERes2Net.py ADDED
@@ -0,0 +1,260 @@
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ """
5
+ Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
6
+ ERes2Net incorporates both local and global feature fusion techniques to improve performance.
7
+ The local feature fusion (LFF) fuses the features within a single residual block to extract the local signal.
8
+ The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate the global signal.
9
+ """
10
+
11
+
12
+ import torch
13
+ import math
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import pooling_layers as pooling_layers
17
+ from fusion import AFF
18
+
19
+ class ReLU(nn.Hardtanh):
20
+
21
+ def __init__(self, inplace=False):
22
+ super(ReLU, self).__init__(0, 20, inplace)
23
+
24
+ def __repr__(self):
25
+ inplace_str = 'inplace' if self.inplace else ''
26
+ return self.__class__.__name__ + ' (' \
27
+ + inplace_str + ')'
28
+
29
+
30
+ class BasicBlockERes2Net(nn.Module):
31
+ expansion = 2
32
+
33
+ def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
34
+ super(BasicBlockERes2Net, self).__init__()
35
+ width = int(math.floor(planes*(baseWidth/64.0)))
36
+ self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
37
+ self.bn1 = nn.BatchNorm2d(width*scale)
38
+ self.nums = scale
39
+
40
+ convs=[]
41
+ bns=[]
42
+ for i in range(self.nums):
43
+ convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
44
+ bns.append(nn.BatchNorm2d(width))
45
+ self.convs = nn.ModuleList(convs)
46
+ self.bns = nn.ModuleList(bns)
47
+ self.relu = ReLU(inplace=True)
48
+
49
+ self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
50
+ self.bn3 = nn.BatchNorm2d(planes*self.expansion)
51
+ self.shortcut = nn.Sequential()
52
+ if stride != 1 or in_planes != self.expansion * planes:
53
+ self.shortcut = nn.Sequential(
54
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
55
+ stride=stride, bias=False),
56
+ nn.BatchNorm2d(self.expansion * planes))
57
+ self.stride = stride
58
+ self.width = width
59
+ self.scale = scale
60
+
61
+ def forward(self, x):
62
+ residual = x
63
+
64
+ out = self.conv1(x)
65
+ out = self.bn1(out)
66
+ out = self.relu(out)
67
+ spx = torch.split(out,self.width,1)
68
+ for i in range(self.nums):
69
+ if i==0:
70
+ sp = spx[i]
71
+ else:
72
+ sp = sp + spx[i]
73
+ sp = self.convs[i](sp)
74
+ sp = self.relu(self.bns[i](sp))
75
+ if i==0:
76
+ out = sp
77
+ else:
78
+ out = torch.cat((out,sp),1)
79
+
80
+ out = self.conv3(out)
81
+ out = self.bn3(out)
82
+
83
+ residual = self.shortcut(x)
84
+ out += residual
85
+ out = self.relu(out)
86
+
87
+ return out
88
+
89
+ class BasicBlockERes2Net_diff_AFF(nn.Module):
90
+ expansion = 2
91
+
92
+ def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
93
+ super(BasicBlockERes2Net_diff_AFF, self).__init__()
94
+ width = int(math.floor(planes*(baseWidth/64.0)))
95
+ self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
96
+ self.bn1 = nn.BatchNorm2d(width*scale)
97
+ self.nums = scale
98
+
99
+ convs=[]
100
+ fuse_models=[]
101
+ bns=[]
102
+ for i in range(self.nums):
103
+ convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
104
+ bns.append(nn.BatchNorm2d(width))
105
+ for j in range(self.nums - 1):
106
+ fuse_models.append(AFF(channels=width))
107
+
108
+ self.convs = nn.ModuleList(convs)
109
+ self.bns = nn.ModuleList(bns)
110
+ self.fuse_models = nn.ModuleList(fuse_models)
111
+ self.relu = ReLU(inplace=True)
112
+
113
+ self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
114
+ self.bn3 = nn.BatchNorm2d(planes*self.expansion)
115
+ self.shortcut = nn.Sequential()
116
+ if stride != 1 or in_planes != self.expansion * planes:
117
+ self.shortcut = nn.Sequential(
118
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
119
+ stride=stride, bias=False),
120
+ nn.BatchNorm2d(self.expansion * planes))
121
+ self.stride = stride
122
+ self.width = width
123
+ self.scale = scale
124
+
125
+ def forward(self, x):
126
+ residual = x
127
+
128
+ out = self.conv1(x)
129
+ out = self.bn1(out)
130
+ out = self.relu(out)
131
+ spx = torch.split(out,self.width,1)
132
+ for i in range(self.nums):
133
+ if i==0:
134
+ sp = spx[i]
135
+ else:
136
+ sp = self.fuse_models[i-1](sp, spx[i])
137
+
138
+ sp = self.convs[i](sp)
139
+ sp = self.relu(self.bns[i](sp))
140
+ if i==0:
141
+ out = sp
142
+ else:
143
+ out = torch.cat((out,sp),1)
144
+
145
+ out = self.conv3(out)
146
+ out = self.bn3(out)
147
+
148
+ residual = self.shortcut(x)
149
+ out += residual
150
+ out = self.relu(out)
151
+
152
+ return out
153
+
154
+ class ERes2Net(nn.Module):
155
+ def __init__(self,
156
+ block=BasicBlockERes2Net,
157
+ block_fuse=BasicBlockERes2Net_diff_AFF,
158
+ num_blocks=[3, 4, 6, 3],
159
+ m_channels=32,
160
+ feat_dim=80,
161
+ embedding_size=192,
162
+ pooling_func='TSTP',
163
+ two_emb_layer=False):
164
+ super(ERes2Net, self).__init__()
165
+ self.in_planes = m_channels
166
+ self.feat_dim = feat_dim
167
+ self.embedding_size = embedding_size
168
+ self.stats_dim = int(feat_dim / 8) * m_channels * 8
169
+ self.two_emb_layer = two_emb_layer
170
+
171
+ self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
172
+ self.bn1 = nn.BatchNorm2d(m_channels)
173
+ self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
174
+ self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
175
+ self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
176
+ self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
177
+
178
+ # Downsampling module for each layer
179
+ self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False)
180
+ self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
181
+ self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
182
+
183
+ # Bottom-up fusion module
184
+ self.fuse_mode12 = AFF(channels=m_channels * 4)
185
+ self.fuse_mode123 = AFF(channels=m_channels * 8)
186
+ self.fuse_mode1234 = AFF(channels=m_channels * 16)
187
+
188
+ self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
189
+ self.pool = getattr(pooling_layers, pooling_func)(
190
+ in_dim=self.stats_dim * block.expansion)
191
+ self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
192
+ embedding_size)
193
+ if self.two_emb_layer:
194
+ self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
195
+ self.seg_2 = nn.Linear(embedding_size, embedding_size)
196
+ else:
197
+ self.seg_bn_1 = nn.Identity()
198
+ self.seg_2 = nn.Identity()
199
+
200
+ def _make_layer(self, block, planes, num_blocks, stride):
201
+ strides = [stride] + [1] * (num_blocks - 1)
202
+ layers = []
203
+ for stride in strides:
204
+ layers.append(block(self.in_planes, planes, stride))
205
+ self.in_planes = planes * block.expansion
206
+ return nn.Sequential(*layers)
207
+
208
+ def forward(self, x):
209
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
210
+ x = x.unsqueeze_(1)
211
+ out = F.relu(self.bn1(self.conv1(x)))
212
+ out1 = self.layer1(out)
213
+ out2 = self.layer2(out1)
214
+ out1_downsample = self.layer1_downsample(out1)
215
+ fuse_out12 = self.fuse_mode12(out2, out1_downsample)
216
+ out3 = self.layer3(out2)
217
+ fuse_out12_downsample = self.layer2_downsample(fuse_out12)
218
+ fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
219
+ out4 = self.layer4(out3)
220
+ fuse_out123_downsample = self.layer3_downsample(fuse_out123)
221
+ fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
222
+ stats = self.pool(fuse_out1234)
223
+
224
+ embed_a = self.seg_1(stats)
225
+ if self.two_emb_layer:
226
+ out = F.relu(embed_a)
227
+ out = self.seg_bn_1(out)
228
+ embed_b = self.seg_2(out)
229
+ return embed_b
230
+ else:
231
+ return embed_a
232
+
233
+ def forward3(self, x):
234
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
235
+ x = x.unsqueeze_(1)
236
+ out = F.relu(self.bn1(self.conv1(x)))
237
+ out1 = self.layer1(out)
238
+ out2 = self.layer2(out1)
239
+ out1_downsample = self.layer1_downsample(out1)
240
+ fuse_out12 = self.fuse_mode12(out2, out1_downsample)
241
+ out3 = self.layer3(out2)
242
+ fuse_out12_downsample = self.layer2_downsample(fuse_out12)
243
+ fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
244
+ out4 = self.layer4(out3)
245
+ fuse_out123_downsample = self.layer3_downsample(fuse_out123)
246
+ fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1,end_dim=2).mean(-1)
247
+ return fuse_out1234
248
+
249
+
250
+ if __name__ == '__main__':
251
+
252
+ x = torch.zeros(10, 300, 80)
253
+ model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func='TSTP')
254
+ model.eval()
255
+ out = model(x)
256
+ print(out.shape) # torch.Size([10, 192])
257
+
258
+ num_params = sum(param.numel() for param in model.parameters())
259
+ print("{} M".format(num_params / 1e6)) # 6.61M
260
+
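A minimal end-to-end sketch (not part of the commit, for orientation only): the ERes2Net model above expects 80-dim fbank features of shape (B, T, 80) and returns a 192-dim speaker embedding. The file name "speaker.wav" and the mean-normalization step are illustrative assumptions; kaldi.py and ERes2Net.py are the modules added in this folder.

    import torch
    import torchaudio
    import kaldi as Kaldi            # eres2net/kaldi.py from this commit
    from ERes2Net import ERes2Net    # eres2net/ERes2Net.py from this commit

    wav, sr = torchaudio.load("speaker.wav")                          # hypothetical 16 kHz mono file
    feat = Kaldi.fbank(wav, num_mel_bins=80, sample_frequency=sr, dither=0.0)  # (T, 80)
    feat = feat - feat.mean(dim=0, keepdim=True)                      # simple cepstral mean normalization (optional)

    model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func="TSTP").eval()
    with torch.no_grad():
        embedding = model(feat.unsqueeze(0))                          # (1, 192)
    print(embedding.shape)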
eres2net/ERes2NetV2.py ADDED
@@ -0,0 +1,292 @@
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ """
5
+ To further improve the short-duration feature extraction capability of ERes2Net, we expand the channel dimension
6
+ within each stage. However, this modification also increases the number of model parameters and computational complexity.
7
+ To alleviate this problem, we propose an improved ERes2NetV2 that prunes redundant structures, ultimately reducing
8
+ both the model parameters and the computational cost.
9
+ """
10
+
11
+
12
+
13
+ import torch
14
+ import math
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import pooling_layers as pooling_layers
18
+ from fusion import AFF
19
+
20
+ class ReLU(nn.Hardtanh):
21
+
22
+ def __init__(self, inplace=False):
23
+ super(ReLU, self).__init__(0, 20, inplace)
24
+
25
+ def __repr__(self):
26
+ inplace_str = 'inplace' if self.inplace else ''
27
+ return self.__class__.__name__ + ' (' \
28
+ + inplace_str + ')'
29
+
30
+
31
+ class BasicBlockERes2NetV2(nn.Module):
32
+
33
+ def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
34
+ super(BasicBlockERes2NetV2, self).__init__()
35
+ width = int(math.floor(planes*(baseWidth/64.0)))
36
+ self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
37
+ self.bn1 = nn.BatchNorm2d(width*scale)
38
+ self.nums = scale
39
+ self.expansion = expansion
40
+
41
+ convs=[]
42
+ bns=[]
43
+ for i in range(self.nums):
44
+ convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
45
+ bns.append(nn.BatchNorm2d(width))
46
+ self.convs = nn.ModuleList(convs)
47
+ self.bns = nn.ModuleList(bns)
48
+ self.relu = ReLU(inplace=True)
49
+
50
+ self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
51
+ self.bn3 = nn.BatchNorm2d(planes*self.expansion)
52
+ self.shortcut = nn.Sequential()
53
+ if stride != 1 or in_planes != self.expansion * planes:
54
+ self.shortcut = nn.Sequential(
55
+ nn.Conv2d(in_planes,
56
+ self.expansion * planes,
57
+ kernel_size=1,
58
+ stride=stride,
59
+ bias=False),
60
+ nn.BatchNorm2d(self.expansion * planes))
61
+ self.stride = stride
62
+ self.width = width
63
+ self.scale = scale
64
+
65
+ def forward(self, x):
66
+ residual = x
67
+
68
+ out = self.conv1(x)
69
+ out = self.bn1(out)
70
+ out = self.relu(out)
71
+ spx = torch.split(out,self.width,1)
72
+ for i in range(self.nums):
73
+ if i==0:
74
+ sp = spx[i]
75
+ else:
76
+ sp = sp + spx[i]
77
+ sp = self.convs[i](sp)
78
+ sp = self.relu(self.bns[i](sp))
79
+ if i==0:
80
+ out = sp
81
+ else:
82
+ out = torch.cat((out,sp),1)
83
+
84
+ out = self.conv3(out)
85
+ out = self.bn3(out)
86
+
87
+ residual = self.shortcut(x)
88
+ out += residual
89
+ out = self.relu(out)
90
+
91
+ return out
92
+
93
+ class BasicBlockERes2NetV2AFF(nn.Module):
94
+
95
+ def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
96
+ super(BasicBlockERes2NetV2AFF, self).__init__()
97
+ width = int(math.floor(planes*(baseWidth/64.0)))
98
+ self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
99
+ self.bn1 = nn.BatchNorm2d(width*scale)
100
+ self.nums = scale
101
+ self.expansion = expansion
102
+
103
+ convs=[]
104
+ fuse_models=[]
105
+ bns=[]
106
+ for i in range(self.nums):
107
+ convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
108
+ bns.append(nn.BatchNorm2d(width))
109
+ for j in range(self.nums - 1):
110
+ fuse_models.append(AFF(channels=width, r=4))
111
+
112
+ self.convs = nn.ModuleList(convs)
113
+ self.bns = nn.ModuleList(bns)
114
+ self.fuse_models = nn.ModuleList(fuse_models)
115
+ self.relu = ReLU(inplace=True)
116
+
117
+ self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
118
+ self.bn3 = nn.BatchNorm2d(planes*self.expansion)
119
+ self.shortcut = nn.Sequential()
120
+ if stride != 1 or in_planes != self.expansion * planes:
121
+ self.shortcut = nn.Sequential(
122
+ nn.Conv2d(in_planes,
123
+ self.expansion * planes,
124
+ kernel_size=1,
125
+ stride=stride,
126
+ bias=False),
127
+ nn.BatchNorm2d(self.expansion * planes))
128
+ self.stride = stride
129
+ self.width = width
130
+ self.scale = scale
131
+
132
+ def forward(self, x):
133
+ residual = x
134
+
135
+ out = self.conv1(x)
136
+ out = self.bn1(out)
137
+ out = self.relu(out)
138
+ spx = torch.split(out,self.width,1)
139
+ for i in range(self.nums):
140
+ if i==0:
141
+ sp = spx[i]
142
+ else:
143
+ sp = self.fuse_models[i-1](sp, spx[i])
144
+
145
+ sp = self.convs[i](sp)
146
+ sp = self.relu(self.bns[i](sp))
147
+ if i==0:
148
+ out = sp
149
+ else:
150
+ out = torch.cat((out,sp),1)
151
+
152
+ out = self.conv3(out)
153
+ out = self.bn3(out)
154
+
155
+ residual = self.shortcut(x)
156
+ out += residual
157
+ out = self.relu(out)
158
+
159
+ return out
160
+
161
+ class ERes2NetV2(nn.Module):
162
+ def __init__(self,
163
+ block=BasicBlockERes2NetV2,
164
+ block_fuse=BasicBlockERes2NetV2AFF,
165
+ num_blocks=[3, 4, 6, 3],
166
+ m_channels=64,
167
+ feat_dim=80,
168
+ embedding_size=192,
169
+ baseWidth=26,
170
+ scale=2,
171
+ expansion=2,
172
+ pooling_func='TSTP',
173
+ two_emb_layer=False):
174
+ super(ERes2NetV2, self).__init__()
175
+ self.in_planes = m_channels
176
+ self.feat_dim = feat_dim
177
+ self.embedding_size = embedding_size
178
+ self.stats_dim = int(feat_dim / 8) * m_channels * 8
179
+ self.two_emb_layer = two_emb_layer
180
+ self.baseWidth = baseWidth
181
+ self.scale = scale
182
+ self.expansion = expansion
183
+
184
+ self.conv1 = nn.Conv2d(1,
185
+ m_channels,
186
+ kernel_size=3,
187
+ stride=1,
188
+ padding=1,
189
+ bias=False)
190
+ self.bn1 = nn.BatchNorm2d(m_channels)
191
+ self.layer1 = self._make_layer(block,
192
+ m_channels,
193
+ num_blocks[0],
194
+ stride=1)
195
+ self.layer2 = self._make_layer(block,
196
+ m_channels * 2,
197
+ num_blocks[1],
198
+ stride=2)
199
+ self.layer3 = self._make_layer(block_fuse,
200
+ m_channels * 4,
201
+ num_blocks[2],
202
+ stride=2)
203
+ self.layer4 = self._make_layer(block_fuse,
204
+ m_channels * 8,
205
+ num_blocks[3],
206
+ stride=2)
207
+
208
+ # Downsampling module
209
+ self.layer3_ds = nn.Conv2d(m_channels * 4 * self.expansion, m_channels * 8 * self.expansion, kernel_size=3, \
210
+ padding=1, stride=2, bias=False)
211
+
212
+ # Bottom-up fusion module
213
+ self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)
214
+
215
+ self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
216
+ self.pool = getattr(pooling_layers, pooling_func)(
217
+ in_dim=self.stats_dim * self.expansion)
218
+ self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats,
219
+ embedding_size)
220
+ if self.two_emb_layer:
221
+ self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
222
+ self.seg_2 = nn.Linear(embedding_size, embedding_size)
223
+ else:
224
+ self.seg_bn_1 = nn.Identity()
225
+ self.seg_2 = nn.Identity()
226
+
227
+ def _make_layer(self, block, planes, num_blocks, stride):
228
+ strides = [stride] + [1] * (num_blocks - 1)
229
+ layers = []
230
+ for stride in strides:
231
+ layers.append(block(self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion))
232
+ self.in_planes = planes * self.expansion
233
+ return nn.Sequential(*layers)
234
+
235
+ def forward(self, x):
236
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
237
+ x = x.unsqueeze_(1)
238
+ out = F.relu(self.bn1(self.conv1(x)))
239
+ out1 = self.layer1(out)
240
+ out2 = self.layer2(out1)
241
+ out3 = self.layer3(out2)
242
+ out4 = self.layer4(out3)
243
+ out3_ds = self.layer3_ds(out3)
244
+ fuse_out34 = self.fuse34(out4, out3_ds)
245
+ stats = self.pool(fuse_out34)
246
+
247
+ embed_a = self.seg_1(stats)
248
+ if self.two_emb_layer:
249
+ out = F.relu(embed_a)
250
+ out = self.seg_bn_1(out)
251
+ embed_b = self.seg_2(out)
252
+ return embed_b
253
+ else:
254
+ return embed_a
255
+
256
+ def forward3(self, x):
257
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
258
+ x = x.unsqueeze_(1)
259
+ out = F.relu(self.bn1(self.conv1(x)))
260
+ out1 = self.layer1(out)
261
+ out2 = self.layer2(out1)
262
+ out3 = self.layer3(out2)
263
+ out4 = self.layer4(out3)
264
+ out3_ds = self.layer3_ds(out3)
265
+ fuse_out34 = self.fuse34(out4, out3_ds)
266
+ # flatten the channel and frequency dims, then average over time to get an utterance-level vector
+ return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)
278
+
279
+ if __name__ == '__main__':
280
+
281
+ x = torch.randn(1, 300, 80)
282
+ model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64, baseWidth=26, scale=2, expansion=2)
283
+ model.eval()
284
+ y = model(x)
285
+ print(y.size())
286
+ from thop import profile  # `profile` comes from the thop package and was not imported above
+ macs, num_params = profile(model, inputs=(x, ))
287
+ print("Params: {} M".format(num_params / 1e6)) # 17.86 M
288
+ print("MACs: {} G".format(macs / 1e9)) # 12.69 G
289
+
290
+
291
+
292
+
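A small comparison sketch (illustrative only, assuming both modules are importable from this folder): the wider V2 configuration used here trades parameter count for short-duration robustness, roughly 6.6 M parameters for the base ERes2Net above versus about 17.9 M for ERes2NetV2, matching the figures quoted in the two __main__ blocks.

    from ERes2Net import ERes2Net
    from ERes2NetV2 import ERes2NetV2

    for name, model in [
        ("ERes2Net   (m_channels=32)", ERes2Net(feat_dim=80, embedding_size=192)),
        ("ERes2NetV2 (m_channels=64)", ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64)),
    ]:
        n_params = sum(p.numel() for p in model.parameters())
        print("{}: {:.2f} M params".format(name, n_params / 1e6))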
eres2net/ERes2Net_huge.py ADDED
@@ -0,0 +1,286 @@
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ """ Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
5
+ ERes2Net incorporates both local and global feature fusion techniques to improve performance.
6
+ The local feature fusion (LFF) fuses the features within a single residual block to extract the local signal.
7
+ The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate the global signal.
8
+ ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
9
+ recognition performance. The parameters expansion, baseWidth, and scale can be modified for optimal performance.
10
+ """
12
+
13
+ import torch
14
+ import math
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import pooling_layers as pooling_layers
18
+ from fusion import AFF
19
+
20
+ class ReLU(nn.Hardtanh):
21
+
22
+ def __init__(self, inplace=False):
23
+ super(ReLU, self).__init__(0, 20, inplace)
24
+
25
+ def __repr__(self):
26
+ inplace_str = 'inplace' if self.inplace else ''
27
+ return self.__class__.__name__ + ' (' \
28
+ + inplace_str + ')'
29
+
30
+
31
+ class BasicBlockERes2Net(nn.Module):
32
+ expansion = 4
33
+
34
+ def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
35
+ super(BasicBlockERes2Net, self).__init__()
36
+ width = int(math.floor(planes*(baseWidth/64.0)))
37
+ self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
38
+ self.bn1 = nn.BatchNorm2d(width*scale)
39
+ self.nums = scale
40
+
41
+ convs=[]
42
+ bns=[]
43
+ for i in range(self.nums):
44
+ convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
45
+ bns.append(nn.BatchNorm2d(width))
46
+ self.convs = nn.ModuleList(convs)
47
+ self.bns = nn.ModuleList(bns)
48
+ self.relu = ReLU(inplace=True)
49
+
50
+ self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
51
+ self.bn3 = nn.BatchNorm2d(planes*self.expansion)
52
+ self.shortcut = nn.Sequential()
53
+ if stride != 1 or in_planes != self.expansion * planes:
54
+ self.shortcut = nn.Sequential(
55
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
56
+ nn.BatchNorm2d(self.expansion * planes))
57
+ self.stride = stride
58
+ self.width = width
59
+ self.scale = scale
60
+
61
+ def forward(self, x):
62
+ residual = x
63
+
64
+ out = self.conv1(x)
65
+ out = self.bn1(out)
66
+ out = self.relu(out)
67
+ spx = torch.split(out,self.width,1)
68
+ for i in range(self.nums):
69
+ if i==0:
70
+ sp = spx[i]
71
+ else:
72
+ sp = sp + spx[i]
73
+ sp = self.convs[i](sp)
74
+ sp = self.relu(self.bns[i](sp))
75
+ if i==0:
76
+ out = sp
77
+ else:
78
+ out = torch.cat((out,sp),1)
79
+
80
+ out = self.conv3(out)
81
+ out = self.bn3(out)
82
+
83
+ residual = self.shortcut(x)
84
+ out += residual
85
+ out = self.relu(out)
86
+
87
+ return out
88
+
89
+ class BasicBlockERes2Net_diff_AFF(nn.Module):
90
+ expansion = 4
91
+
92
+ def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
93
+ super(BasicBlockERes2Net_diff_AFF, self).__init__()
94
+ width = int(math.floor(planes*(baseWidth/64.0)))
95
+ self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
96
+ self.bn1 = nn.BatchNorm2d(width*scale)
97
+ self.nums = scale
98
+
99
+ convs=[]
100
+ fuse_models=[]
101
+ bns=[]
102
+ for i in range(self.nums):
103
+ convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
104
+ bns.append(nn.BatchNorm2d(width))
105
+ for j in range(self.nums - 1):
106
+ fuse_models.append(AFF(channels=width))
107
+
108
+ self.convs = nn.ModuleList(convs)
109
+ self.bns = nn.ModuleList(bns)
110
+ self.fuse_models = nn.ModuleList(fuse_models)
111
+ self.relu = ReLU(inplace=True)
112
+
113
+ self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
114
+ self.bn3 = nn.BatchNorm2d(planes*self.expansion)
115
+ self.shortcut = nn.Sequential()
116
+ if stride != 1 or in_planes != self.expansion * planes:
117
+ self.shortcut = nn.Sequential(
118
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
119
+ nn.BatchNorm2d(self.expansion * planes))
120
+ self.stride = stride
121
+ self.width = width
122
+ self.scale = scale
123
+
124
+ def forward(self, x):
125
+ residual = x
126
+
127
+ out = self.conv1(x)
128
+ out = self.bn1(out)
129
+ out = self.relu(out)
130
+ spx = torch.split(out,self.width,1)
131
+ for i in range(self.nums):
132
+ if i==0:
133
+ sp = spx[i]
134
+ else:
135
+ sp = self.fuse_models[i-1](sp, spx[i])
136
+
137
+ sp = self.convs[i](sp)
138
+ sp = self.relu(self.bns[i](sp))
139
+ if i==0:
140
+ out = sp
141
+ else:
142
+ out = torch.cat((out,sp),1)
143
+
144
+
145
+ out = self.conv3(out)
146
+ out = self.bn3(out)
147
+
148
+ residual = self.shortcut(x)
149
+ out += residual
150
+ out = self.relu(out)
151
+
152
+ return out
153
+
154
+ class ERes2Net(nn.Module):
155
+ def __init__(self,
156
+ block=BasicBlockERes2Net,
157
+ block_fuse=BasicBlockERes2Net_diff_AFF,
158
+ num_blocks=[3, 4, 6, 3],
159
+ m_channels=64,
160
+ feat_dim=80,
161
+ embedding_size=192,
162
+ pooling_func='TSTP',
163
+ two_emb_layer=False):
164
+ super(ERes2Net, self).__init__()
165
+ self.in_planes = m_channels
166
+ self.feat_dim = feat_dim
167
+ self.embedding_size = embedding_size
168
+ self.stats_dim = int(feat_dim / 8) * m_channels * 8
169
+ self.two_emb_layer = two_emb_layer
170
+
171
+ self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
172
+ self.bn1 = nn.BatchNorm2d(m_channels)
173
+
174
+ self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
175
+ self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
176
+ self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
177
+ self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
178
+
179
+ self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
180
+ self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
181
+ self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False)
182
+
183
+ self.fuse_mode12 = AFF(channels=m_channels * 8)
184
+ self.fuse_mode123 = AFF(channels=m_channels * 16)
185
+ self.fuse_mode1234 = AFF(channels=m_channels * 32)
186
+
187
+ self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
188
+ self.pool = getattr(pooling_layers, pooling_func)(
189
+ in_dim=self.stats_dim * block.expansion)
190
+ self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
191
+ if self.two_emb_layer:
192
+ self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
193
+ self.seg_2 = nn.Linear(embedding_size, embedding_size)
194
+ else:
195
+ self.seg_bn_1 = nn.Identity()
196
+ self.seg_2 = nn.Identity()
197
+
198
+ def _make_layer(self, block, planes, num_blocks, stride):
199
+ strides = [stride] + [1] * (num_blocks - 1)
200
+ layers = []
201
+ for stride in strides:
202
+ layers.append(block(self.in_planes, planes, stride))
203
+ self.in_planes = planes * block.expansion
204
+ return nn.Sequential(*layers)
205
+
206
+ def forward(self, x):
207
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
208
+
209
+ x = x.unsqueeze_(1)
210
+ out = F.relu(self.bn1(self.conv1(x)))
211
+ out1 = self.layer1(out)
212
+ out2 = self.layer2(out1)
213
+ out1_downsample = self.layer1_downsample(out1)
214
+ fuse_out12 = self.fuse_mode12(out2, out1_downsample)
215
+ out3 = self.layer3(out2)
216
+ fuse_out12_downsample = self.layer2_downsample(fuse_out12)
217
+ fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
218
+ out4 = self.layer4(out3)
219
+ fuse_out123_downsample = self.layer3_downsample(fuse_out123)
220
+ fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
221
+ stats = self.pool(fuse_out1234)
222
+
223
+ embed_a = self.seg_1(stats)
224
+ if self.two_emb_layer:
225
+ out = F.relu(embed_a)
226
+ out = self.seg_bn_1(out)
227
+ embed_b = self.seg_2(out)
228
+ return embed_b
229
+ else:
230
+ return embed_a
231
+
232
+ def forward2(self, x,if_mean):
233
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
234
+
235
+ x = x.unsqueeze_(1)
236
+ out = F.relu(self.bn1(self.conv1(x)))
237
+ out1 = self.layer1(out)
238
+ out2 = self.layer2(out1)
239
+ out1_downsample = self.layer1_downsample(out1)
240
+ fuse_out12 = self.fuse_mode12(out2, out1_downsample)
241
+ out3 = self.layer3(out2)
242
+ fuse_out12_downsample = self.layer2_downsample(fuse_out12)
243
+ fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
244
+ out4 = self.layer4(out3)
245
+ fuse_out123_downsample = self.layer3_downsample(fuse_out123)
246
+ fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1,end_dim=2)#bs,20480,T
247
+ if(if_mean==False):
248
+ mean=fuse_out1234[0].transpose(1,0)#(T,20480),bs=T
249
+ else:
250
+ mean = fuse_out1234.mean(2)#bs,20480
251
+ mean_std=torch.cat([mean,torch.zeros_like(mean)],1)
252
+ return self.seg_1(mean_std)#(T,192)
253
+
254
+
263
+
264
+ def forward3(self, x):
265
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
266
+
267
+ x = x.unsqueeze_(1)
268
+ out = F.relu(self.bn1(self.conv1(x)))
269
+ out1 = self.layer1(out)
270
+ out2 = self.layer2(out1)
271
+ out1_downsample = self.layer1_downsample(out1)
272
+ fuse_out12 = self.fuse_mode12(out2, out1_downsample)
273
+ out3 = self.layer3(out2)
274
+ fuse_out12_downsample = self.layer2_downsample(fuse_out12)
275
+ fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
276
+ out4 = self.layer4(out3)
277
+ fuse_out123_downsample = self.layer3_downsample(fuse_out123)
278
+ fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1,end_dim=2).mean(-1)
279
+ return fuse_out1234
283
+
284
+
285
+
286
+
eres2net/fusion.py ADDED
@@ -0,0 +1,29 @@
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+ import torch
+ import torch.nn as nn
+
+
+ class AFF(nn.Module):
+
+     def __init__(self, channels=64, r=4):
+         super(AFF, self).__init__()
+         inter_channels = int(channels // r)
+
+         self.local_att = nn.Sequential(
+             nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
+             nn.BatchNorm2d(inter_channels),
+             nn.SiLU(inplace=True),
+             nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+             nn.BatchNorm2d(channels),
+         )
+
+     def forward(self, x, ds_y):
+         xa = torch.cat((x, ds_y), dim=1)
+         x_att = self.local_att(xa)
+         x_att = 1.0 + torch.tanh(x_att)
+         xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att)
+
+         return xo
+
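A brief usage sketch (illustrative only): AFF fuses two feature maps of identical shape by predicting a per-position gate from their concatenation; since the gate is 1 + tanh(·), the two branch weights x_att and 2 − x_att always sum to 2.

    import torch
    from fusion import AFF

    aff = AFF(channels=64, r=4)
    x = torch.randn(2, 64, 10, 50)   # e.g. features from the current stage
    y = torch.randn(2, 64, 10, 50)   # e.g. downsampled features from an earlier stage
    fused = aff(x, y)                # same shape as the inputs
    print(fused.shape)               # torch.Size([2, 64, 10, 50])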
eres2net/kaldi.py ADDED
@@ -0,0 +1,819 @@
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torchaudio
6
+ from torch import Tensor
7
+
8
+ __all__ = [
9
+ "get_mel_banks",
10
+ "inverse_mel_scale",
11
+ "inverse_mel_scale_scalar",
12
+ "mel_scale",
13
+ "mel_scale_scalar",
14
+ "spectrogram",
15
+ "fbank",
16
+ "mfcc",
17
+ "vtln_warp_freq",
18
+ "vtln_warp_mel_freq",
19
+ ]
20
+
21
+ # numeric_limits<float>::epsilon() 1.1920928955078125e-07
22
+ EPSILON = torch.tensor(torch.finfo(torch.float).eps)
23
+ # 1 milliseconds = 0.001 seconds
24
+ MILLISECONDS_TO_SECONDS = 0.001
25
+
26
+ # window types
27
+ HAMMING = "hamming"
28
+ HANNING = "hanning"
29
+ POVEY = "povey"
30
+ RECTANGULAR = "rectangular"
31
+ BLACKMAN = "blackman"
32
+ WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
33
+
34
+
35
+ def _get_epsilon(device, dtype):
36
+ return EPSILON.to(device=device, dtype=dtype)
37
+
38
+
39
+ def _next_power_of_2(x: int) -> int:
40
+ r"""Returns the smallest power of 2 that is greater than x"""
41
+ return 1 if x == 0 else 2 ** (x - 1).bit_length()
42
+
43
+
44
+ def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
45
+ r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
46
+ representing how the window is shifted along the waveform. Each row is a frame.
47
+
48
+ Args:
49
+ waveform (Tensor): Tensor of size ``num_samples``
50
+ window_size (int): Frame length
51
+ window_shift (int): Frame shift
52
+ snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
53
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
54
+ depends only on the frame_shift, and we reflect the data at the ends.
55
+
56
+ Returns:
57
+ Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
58
+ """
59
+ assert waveform.dim() == 1
60
+ num_samples = waveform.size(0)
61
+ strides = (window_shift * waveform.stride(0), waveform.stride(0))
62
+
63
+ if snip_edges:
64
+ if num_samples < window_size:
65
+ return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
66
+ else:
67
+ m = 1 + (num_samples - window_size) // window_shift
68
+ else:
69
+ reversed_waveform = torch.flip(waveform, [0])
70
+ m = (num_samples + (window_shift // 2)) // window_shift
71
+ pad = window_size // 2 - window_shift // 2
72
+ pad_right = reversed_waveform
73
+ if pad > 0:
74
+ # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
75
+ # but we want [2, 1, 0, 0, 1, 2]
76
+ pad_left = reversed_waveform[-pad:]
77
+ waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
78
+ else:
79
+ # pad is negative so we want to trim the waveform at the front
80
+ waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
81
+
82
+ sizes = (m, window_size)
83
+ return waveform.as_strided(sizes, strides)
84
+
85
+
86
+ def _feature_window_function(
87
+ window_type: str,
88
+ window_size: int,
89
+ blackman_coeff: float,
90
+ device: torch.device,
91
+ dtype: int,
92
+ ) -> Tensor:
93
+ r"""Returns a window function with the given type and size"""
94
+ if window_type == HANNING:
95
+ return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
96
+ elif window_type == HAMMING:
97
+ return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
98
+ elif window_type == POVEY:
99
+ # like hanning but goes to zero at edges
100
+ return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
101
+ elif window_type == RECTANGULAR:
102
+ return torch.ones(window_size, device=device, dtype=dtype)
103
+ elif window_type == BLACKMAN:
104
+ a = 2 * math.pi / (window_size - 1)
105
+ window_function = torch.arange(window_size, device=device, dtype=dtype)
106
+ # can't use torch.blackman_window as they use different coefficients
107
+ return (
108
+ blackman_coeff
109
+ - 0.5 * torch.cos(a * window_function)
110
+ + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
111
+ ).to(device=device, dtype=dtype)
112
+ else:
113
+ raise Exception("Invalid window type " + window_type)
114
+
115
+
116
+ def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
117
+ r"""Returns the log energy of size (m) for a strided_input (m,*)"""
118
+ device, dtype = strided_input.device, strided_input.dtype
119
+ log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
120
+ if energy_floor == 0.0:
121
+ return log_energy
122
+ return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
123
+
124
+
125
+ def _get_waveform_and_window_properties(
126
+ waveform: Tensor,
127
+ channel: int,
128
+ sample_frequency: float,
129
+ frame_shift: float,
130
+ frame_length: float,
131
+ round_to_power_of_two: bool,
132
+ preemphasis_coefficient: float,
133
+ ) -> Tuple[Tensor, int, int, int]:
134
+ r"""Gets the waveform and window properties"""
135
+ channel = max(channel, 0)
136
+ assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
137
+ waveform = waveform[channel, :] # size (n)
138
+ window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
139
+ window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
140
+ padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
141
+
142
+ assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
143
+ window_size, len(waveform)
144
+ )
145
+ assert 0 < window_shift, "`window_shift` must be greater than 0"
146
+ assert padded_window_size % 2 == 0, (
147
+ "the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`"
148
+ )
149
+ assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
150
+ assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
151
+ return waveform, window_shift, window_size, padded_window_size
152
+
153
+
154
+ def _get_window(
155
+ waveform: Tensor,
156
+ padded_window_size: int,
157
+ window_size: int,
158
+ window_shift: int,
159
+ window_type: str,
160
+ blackman_coeff: float,
161
+ snip_edges: bool,
162
+ raw_energy: bool,
163
+ energy_floor: float,
164
+ dither: float,
165
+ remove_dc_offset: bool,
166
+ preemphasis_coefficient: float,
167
+ ) -> Tuple[Tensor, Tensor]:
168
+ r"""Gets a window and its log energy
169
+
170
+ Returns:
171
+ (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
172
+ """
173
+ device, dtype = waveform.device, waveform.dtype
174
+ epsilon = _get_epsilon(device, dtype)
175
+
176
+ # size (m, window_size)
177
+ strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
178
+
179
+ if dither != 0.0:
180
+ rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype)
181
+ strided_input = strided_input + rand_gauss * dither
182
+
183
+ if remove_dc_offset:
184
+ # Subtract each row/frame by its mean
185
+ row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1)
186
+ strided_input = strided_input - row_means
187
+
188
+ if raw_energy:
189
+ # Compute the log energy of each row/frame before applying preemphasis and
190
+ # window function
191
+ signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
192
+
193
+ if preemphasis_coefficient != 0.0:
194
+ # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
195
+ offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(
196
+ 0
197
+ ) # size (m, window_size + 1)
198
+ strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]
199
+
200
+ # Apply window_function to each row/frame
201
+ window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze(
202
+ 0
203
+ ) # size (1, window_size)
204
+ strided_input = strided_input * window_function # size (m, window_size)
205
+
206
+ # Pad columns with zero until we reach size (m, padded_window_size)
207
+ if padded_window_size != window_size:
208
+ padding_right = padded_window_size - window_size
209
+ strided_input = torch.nn.functional.pad(
210
+ strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
211
+ ).squeeze(0)
212
+
213
+ # Compute energy after window function (not the raw one)
214
+ if not raw_energy:
215
+ signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
216
+
217
+ return strided_input, signal_log_energy
218
+
219
+
220
+ def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
221
+ # subtracts the column mean of the tensor size (m, n) if subtract_mean=True
222
+ # it returns size (m, n)
223
+ if subtract_mean:
224
+ col_means = torch.mean(tensor, dim=0).unsqueeze(0)
225
+ tensor = tensor - col_means
226
+ return tensor
227
+
228
+
229
+ def spectrogram(
230
+ waveform: Tensor,
231
+ blackman_coeff: float = 0.42,
232
+ channel: int = -1,
233
+ dither: float = 0.0,
234
+ energy_floor: float = 1.0,
235
+ frame_length: float = 25.0,
236
+ frame_shift: float = 10.0,
237
+ min_duration: float = 0.0,
238
+ preemphasis_coefficient: float = 0.97,
239
+ raw_energy: bool = True,
240
+ remove_dc_offset: bool = True,
241
+ round_to_power_of_two: bool = True,
242
+ sample_frequency: float = 16000.0,
243
+ snip_edges: bool = True,
244
+ subtract_mean: bool = False,
245
+ window_type: str = POVEY,
246
+ ) -> Tensor:
247
+ r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
248
+ compute-spectrogram-feats.
249
+
250
+ Args:
251
+ waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
252
+ blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
253
+ channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
254
+ dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
255
+ the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
256
+ energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
257
+ this floor is applied to the zeroth component, representing the total signal energy. The floor on the
258
+ individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
259
+ frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
260
+ frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
261
+ min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
262
+ preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
263
+ raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
264
+ remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
265
+ round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
266
+ to FFT. (Default: ``True``)
267
+ sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
268
+ specified there) (Default: ``16000.0``)
269
+ snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
270
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
271
+ depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
272
+ subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
273
+ it this way. (Default: ``False``)
274
+ window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
275
+ (Default: ``'povey'``)
276
+
277
+ Returns:
278
+ Tensor: A spectrogram identical to what Kaldi would output. The shape is
279
+ (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
280
+ """
281
+ device, dtype = waveform.device, waveform.dtype
282
+ epsilon = _get_epsilon(device, dtype)
283
+
284
+ waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
285
+ waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
286
+ )
287
+
288
+ if len(waveform) < min_duration * sample_frequency:
289
+ # signal is too short
290
+ return torch.empty(0)
291
+
292
+ strided_input, signal_log_energy = _get_window(
293
+ waveform,
294
+ padded_window_size,
295
+ window_size,
296
+ window_shift,
297
+ window_type,
298
+ blackman_coeff,
299
+ snip_edges,
300
+ raw_energy,
301
+ energy_floor,
302
+ dither,
303
+ remove_dc_offset,
304
+ preemphasis_coefficient,
305
+ )
306
+
307
+ # size (m, padded_window_size // 2 + 1, 2)
308
+ fft = torch.fft.rfft(strided_input)
309
+
310
+ # Convert the FFT into a power spectrum
311
+ power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1)
312
+ power_spectrum[:, 0] = signal_log_energy
313
+
314
+ power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
315
+ return power_spectrum
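An illustrative check (not part of kaldi.py): with the defaults at 16 kHz, a 25 ms frame is 400 samples, padded to 512, so spectrogram() yields 512 // 2 + 1 = 257 log-power bins per frame and one frame every 10 ms hop.

    import torch
    spec = spectrogram(torch.randn(1, 16000), sample_frequency=16000.0)  # one second of noise
    print(spec.shape)  # approximately (98, 257) with the default snip_edges=True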
316
+
317
+
318
+ def inverse_mel_scale_scalar(mel_freq: float) -> float:
319
+ return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)
320
+
321
+
322
+ def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
323
+ return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)
324
+
325
+
326
+ def mel_scale_scalar(freq: float) -> float:
327
+ return 1127.0 * math.log(1.0 + freq / 700.0)
328
+
329
+
330
+ def mel_scale(freq: Tensor) -> Tensor:
331
+ return 1127.0 * (1.0 + freq / 700.0).log()
332
+
333
+
334
+ def vtln_warp_freq(
335
+ vtln_low_cutoff: float,
336
+ vtln_high_cutoff: float,
337
+ low_freq: float,
338
+ high_freq: float,
339
+ vtln_warp_factor: float,
340
+ freq: Tensor,
341
+ ) -> Tensor:
342
+ r"""This computes a VTLN warping function that is not the same as HTK's one,
343
+ but has similar inputs (this function has the advantage of never producing
344
+ empty bins).
345
+
346
+ This function computes a warp function F(freq), defined between low_freq
347
+ and high_freq inclusive, with the following properties:
348
+ F(low_freq) == low_freq
349
+ F(high_freq) == high_freq
350
+ The function is continuous and piecewise linear with two inflection
351
+ points.
352
+ The lower inflection point (measured in terms of the unwarped
353
+ frequency) is at frequency l, determined as described below.
354
+ The higher inflection point is at a frequency h, determined as
355
+ described below.
356
+ If l <= f <= h, then F(f) = f/vtln_warp_factor.
357
+ If the higher inflection point (measured in terms of the unwarped
358
+ frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
359
+ Since (by the last point) F(h) == h/vtln_warp_factor, then
360
+ max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
361
+ h = vtln_high_cutoff / max(1, 1/vtln_warp_factor).
362
+ = vtln_high_cutoff * min(1, vtln_warp_factor).
363
+ If the lower inflection point (measured in terms of the unwarped
364
+ frequency) is at l, then min(l, F(l)) == vtln_low_cutoff
365
+ This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
366
+ = vtln_low_cutoff * max(1, vtln_warp_factor)
367
+ Args:
368
+ vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
369
+ vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
370
+ low_freq (float): Lower frequency cutoffs in mel computation
371
+ high_freq (float): Upper frequency cutoffs in mel computation
372
+ vtln_warp_factor (float): Vtln warp factor
373
+ freq (Tensor): given frequency in Hz
374
+
375
+ Returns:
376
+ Tensor: Freq after vtln warp
377
+ """
378
+ assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
379
+ assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
380
+ l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
381
+ h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
382
+ scale = 1.0 / vtln_warp_factor
383
+ Fl = scale * l # F(l)
384
+ Fh = scale * h # F(h)
385
+ assert l > low_freq and h < high_freq
386
+ # slope of left part of the 3-piece linear function
387
+ scale_left = (Fl - low_freq) / (l - low_freq)
388
+ # [slope of center part is just "scale"]
389
+
390
+ # slope of right part of the 3-piece linear function
391
+ scale_right = (high_freq - Fh) / (high_freq - h)
392
+
393
+ res = torch.empty_like(freq)
394
+
395
+ outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq
396
+ before_l = torch.lt(freq, l) # freq < l
397
+ before_h = torch.lt(freq, h) # freq < h
398
+ after_h = torch.ge(freq, h) # freq >= h
399
+
400
+ # order of operations matter here (since there is overlapping frequency regions)
401
+ res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
402
+ res[before_h] = scale * freq[before_h]
403
+ res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
404
+ res[outside_low_high_freq] = freq[outside_low_high_freq]
405
+
406
+ return res
407
+
408
+
409
+ def vtln_warp_mel_freq(
410
+ vtln_low_cutoff: float,
411
+ vtln_high_cutoff: float,
412
+ low_freq,
413
+ high_freq: float,
414
+ vtln_warp_factor: float,
415
+ mel_freq: Tensor,
416
+ ) -> Tensor:
417
+ r"""
418
+ Args:
419
+ vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
420
+ vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
421
+ low_freq (float): Lower frequency cutoffs in mel computation
422
+ high_freq (float): Upper frequency cutoffs in mel computation
423
+ vtln_warp_factor (float): Vtln warp factor
424
+ mel_freq (Tensor): Given frequency in Mel
425
+
426
+ Returns:
427
+ Tensor: ``mel_freq`` after vtln warp
428
+ """
429
+ return mel_scale(
430
+ vtln_warp_freq(
431
+ vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
432
+ )
433
+ )
434
+
435
+
436
+ def get_mel_banks(
437
+ num_bins: int,
438
+ window_length_padded: int,
439
+ sample_freq: float,
440
+ low_freq: float,
441
+ high_freq: float,
442
+ vtln_low: float,
443
+ vtln_high: float,
444
+ vtln_warp_factor: float,device=None,dtype=None
445
+ ) -> Tuple[Tensor, Tensor]:
446
+ """
447
+ Returns:
448
+ (Tensor, Tensor): The tuple consists of ``bins`` (which is
449
+ melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is
450
+ center frequencies of bins of size (``num_bins``)).
451
+ """
452
+ assert num_bins > 3, "Must have at least 3 mel bins"
453
+ assert window_length_padded % 2 == 0
454
+ num_fft_bins = window_length_padded / 2
455
+ nyquist = 0.5 * sample_freq
456
+
457
+ if high_freq <= 0.0:
458
+ high_freq += nyquist
459
+
460
+ assert (
461
+ (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
462
+ ), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
463
+
464
+ # fft-bin width [think of it as Nyquist-freq / half-window-length]
465
+ fft_bin_width = sample_freq / window_length_padded
466
+ mel_low_freq = mel_scale_scalar(low_freq)
467
+ mel_high_freq = mel_scale_scalar(high_freq)
468
+
469
+ # divide by num_bins+1 in next line because of end-effects where the bins
470
+ # spread out to the sides.
471
+ mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
472
+
473
+ if vtln_high < 0.0:
474
+ vtln_high += nyquist
475
+
476
+ assert vtln_warp_factor == 1.0 or (
477
+ (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
478
+ ), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format(
479
+ vtln_low, vtln_high, low_freq, high_freq
480
+ )
481
+
482
+ bin = torch.arange(num_bins).unsqueeze(1)
483
+ left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
484
+ center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1)
485
+ right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1)
486
+
487
+ if vtln_warp_factor != 1.0:
488
+ left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
489
+ center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
490
+ right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)
491
+
492
+ # center_freqs = inverse_mel_scale(center_mel) # size (num_bins)
493
+ # size(1, num_fft_bins)
494
+ mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)
495
+
496
+ # size (num_bins, num_fft_bins)
497
+ up_slope = (mel - left_mel) / (center_mel - left_mel)
498
+ down_slope = (right_mel - mel) / (right_mel - center_mel)
499
+
500
+ if vtln_warp_factor == 1.0:
501
+ # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
502
+ bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
503
+ else:
504
+ # warping can move the order of left_mel, center_mel, right_mel anywhere
505
+ bins = torch.zeros_like(up_slope)
506
+ up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel
507
+ down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel
508
+ bins[up_idx] = up_slope[up_idx]
509
+ bins[down_idx] = down_slope[down_idx]
510
+
511
+ return bins.to(device=device,dtype=dtype)#, center_freqs
512
+
513
+ cache={}
514
+ def fbank(
515
+ waveform: Tensor,
516
+ blackman_coeff: float = 0.42,
517
+ channel: int = -1,
518
+ dither: float = 0.0,
519
+ energy_floor: float = 1.0,
520
+ frame_length: float = 25.0,
521
+ frame_shift: float = 10.0,
522
+ high_freq: float = 0.0,
523
+ htk_compat: bool = False,
524
+ low_freq: float = 20.0,
525
+ min_duration: float = 0.0,
526
+ num_mel_bins: int = 23,
527
+ preemphasis_coefficient: float = 0.97,
528
+ raw_energy: bool = True,
529
+ remove_dc_offset: bool = True,
530
+ round_to_power_of_two: bool = True,
531
+ sample_frequency: float = 16000.0,
532
+ snip_edges: bool = True,
533
+ subtract_mean: bool = False,
534
+ use_energy: bool = False,
535
+ use_log_fbank: bool = True,
536
+ use_power: bool = True,
537
+ vtln_high: float = -500.0,
538
+ vtln_low: float = 100.0,
539
+ vtln_warp: float = 1.0,
540
+ window_type: str = POVEY,
541
+ ) -> Tensor:
542
+ r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
543
+ compute-fbank-feats.
544
+
545
+ Args:
546
+ waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
547
+ blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
548
+ channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
549
+ dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
550
+ the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
551
+ energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
552
+ this floor is applied to the zeroth component, representing the total signal energy. The floor on the
553
+ individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
554
+ frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
555
+ frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
556
+ high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
557
+ (Default: ``0.0``)
558
+ htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
559
+ (need to change other parameters). (Default: ``False``)
560
+ low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
561
+ min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
562
+ num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
563
+ preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
564
+ raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
565
+ remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
566
+ round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
567
+ to FFT. (Default: ``True``)
568
+ sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
569
+ specified there) (Default: ``16000.0``)
570
+ snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
571
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
572
+ depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
573
+ subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
574
+ it this way. (Default: ``False``)
575
+ use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
576
+ use_log_fbank (bool, optional): If true, produce log-filterbank, else produce linear. (Default: ``True``)
577
+ use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
578
+ vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
579
+ negative, offset from high-mel-freq) (Default: ``-500.0``)
580
+ vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
581
+ vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
582
+ window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
583
+ (Default: ``'povey'``)
584
+
585
+ Returns:
586
+ Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
587
+ where m is calculated in _get_strided
588
+ """
589
+ device, dtype = waveform.device, waveform.dtype
590
+
591
+ waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
592
+ waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
593
+ )
594
+
595
+ if len(waveform) < min_duration * sample_frequency:
596
+ # signal is too short
597
+ return torch.empty(0, device=device, dtype=dtype)
598
+
599
+ # strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
600
+ strided_input, signal_log_energy = _get_window(
601
+ waveform,
602
+ padded_window_size,
603
+ window_size,
604
+ window_shift,
605
+ window_type,
606
+ blackman_coeff,
607
+ snip_edges,
608
+ raw_energy,
609
+ energy_floor,
610
+ dither,
611
+ remove_dc_offset,
612
+ preemphasis_coefficient,
613
+ )
614
+
615
+ # size (m, padded_window_size // 2 + 1)
616
+ spectrum = torch.fft.rfft(strided_input).abs()
617
+ if use_power:
618
+ spectrum = spectrum.pow(2.0)
619
+
620
+ # size (num_mel_bins, padded_window_size // 2)
621
+ # print(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp)
622
+
623
+ cache_key="%s-%s-%s-%s-%s-%s-%s-%s-%s-%s"%(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp,device,dtype)
624
+ if cache_key not in cache:
625
+ mel_energies = get_mel_banks(
626
+ num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp,device,dtype
627
+ )
628
+ cache[cache_key]=mel_energies
629
+ else:
630
+ mel_energies=cache[cache_key]
631
+
632
+ # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
633
+ mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)
634
+
635
+ # sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
636
+ mel_energies = torch.mm(spectrum, mel_energies.T)
637
+ if use_log_fbank:
638
+ # avoid log of zero (which should be prevented anyway by dithering)
639
+ mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
640
+
641
+ # if use_energy then add it as the last column for htk_compat == true else first column
642
+ if use_energy:
643
+ signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1)
644
+ # returns size (m, num_mel_bins + 1)
645
+ if htk_compat:
646
+ mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1)
647
+ else:
648
+ mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)
649
+
650
+ mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
651
+ return mel_energies
652
+
653
+
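A minimal usage sketch of the fbank wrapper above, assuming a mono 16 kHz waveform tensor; the parameters mirror how sv.py calls it further down in this commit:

import torch

waveform = torch.randn(1, 16000)  # 1 s of fake mono audio at 16 kHz, shape (c, n)
feats = fbank(waveform, num_mel_bins=80, sample_frequency=16000.0, dither=0.0)
print(feats.shape)  # roughly (98, 80): 25 ms frames, 10 ms shift, snip_edges=True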
654
+ def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
655
+ # returns a dct matrix of size (num_mel_bins, num_ceps)
656
+ # size (num_mel_bins, num_mel_bins)
657
+ dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
658
+ # kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
659
+ # this would be the first column in the dct_matrix for torchaudio as it expects a
660
+ # right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
661
+ # expects a left multiply e.g. dct_matrix * vector).
662
+ dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
663
+ dct_matrix = dct_matrix[:, :num_ceps]
664
+ return dct_matrix
665
+
666
+
667
+ def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
668
+ # returns size (num_ceps)
669
+ # Compute liftering coefficients (scaling on cepstral coeffs)
670
+ # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
671
+ i = torch.arange(num_ceps)
672
+ return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
673
+
674
+
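A hedged shape sketch of how the two helpers above combine in the mfcc path below: the filterbank features are right-multiplied by the DCT matrix, then scaled element-wise by the lifter coefficients:

import torch

num_mel_bins, num_ceps, cepstral_lifter = 23, 13, 22.0
fake_fbank = torch.randn(100, num_mel_bins)        # (m, num_mel_bins)
dct = _get_dct_matrix(num_ceps, num_mel_bins)      # (num_mel_bins, num_ceps)
ceps = fake_fbank.matmul(dct)                      # (m, num_ceps)
ceps = ceps * _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
print(ceps.shape)  # torch.Size([100, 13])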
675
+ def mfcc(
676
+ waveform: Tensor,
677
+ blackman_coeff: float = 0.42,
678
+ cepstral_lifter: float = 22.0,
679
+ channel: int = -1,
680
+ dither: float = 0.0,
681
+ energy_floor: float = 1.0,
682
+ frame_length: float = 25.0,
683
+ frame_shift: float = 10.0,
684
+ high_freq: float = 0.0,
685
+ htk_compat: bool = False,
686
+ low_freq: float = 20.0,
687
+ num_ceps: int = 13,
688
+ min_duration: float = 0.0,
689
+ num_mel_bins: int = 23,
690
+ preemphasis_coefficient: float = 0.97,
691
+ raw_energy: bool = True,
692
+ remove_dc_offset: bool = True,
693
+ round_to_power_of_two: bool = True,
694
+ sample_frequency: float = 16000.0,
695
+ snip_edges: bool = True,
696
+ subtract_mean: bool = False,
697
+ use_energy: bool = False,
698
+ vtln_high: float = -500.0,
699
+ vtln_low: float = 100.0,
700
+ vtln_warp: float = 1.0,
701
+ window_type: str = POVEY,
702
+ ) -> Tensor:
703
+ r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
704
+ compute-mfcc-feats.
705
+
706
+ Args:
707
+ waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
708
+ blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
709
+ cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
710
+ channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
711
+ dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
712
+ the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
713
+ energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
714
+ this floor is applied to the zeroth component, representing the total signal energy. The floor on the
715
+ individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
716
+ frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
717
+ frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
718
+ high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
719
+ (Default: ``0.0``)
720
+ htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
721
+ features (need to change other parameters). (Default: ``False``)
722
+ low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
723
+ num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
724
+ min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
725
+ num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
726
+ preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
727
+ raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
728
+ remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
729
+ round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
730
+ to FFT. (Default: ``True``)
731
+ sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
732
+ specified there) (Default: ``16000.0``)
733
+ snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
734
+ in the file, and the number of frames depends on the frame_length. If False, the number of frames
735
+ depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
736
+ subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
737
+ it this way. (Default: ``False``)
738
+ use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
739
+ vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
740
+ negative, offset from high-mel-freq) (Default: ``-500.0``)
741
+ vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
742
+ vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
743
+ window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
744
+ (Default: ``"povey"``)
745
+
746
+ Returns:
747
+ Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
748
+ where m is calculated in _get_strided
749
+ """
750
+ assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)
751
+
752
+ device, dtype = waveform.device, waveform.dtype
753
+
754
+ # The mel_energies should be computed from the power spectrum (use_power=True), without
755
+ # mean subtraction (subtract_mean=False), and in log scale (use_log_fbank=True).
756
+ # size (m, num_mel_bins + use_energy)
757
+ feature = fbank(
758
+ waveform=waveform,
759
+ blackman_coeff=blackman_coeff,
760
+ channel=channel,
761
+ dither=dither,
762
+ energy_floor=energy_floor,
763
+ frame_length=frame_length,
764
+ frame_shift=frame_shift,
765
+ high_freq=high_freq,
766
+ htk_compat=htk_compat,
767
+ low_freq=low_freq,
768
+ min_duration=min_duration,
769
+ num_mel_bins=num_mel_bins,
770
+ preemphasis_coefficient=preemphasis_coefficient,
771
+ raw_energy=raw_energy,
772
+ remove_dc_offset=remove_dc_offset,
773
+ round_to_power_of_two=round_to_power_of_two,
774
+ sample_frequency=sample_frequency,
775
+ snip_edges=snip_edges,
776
+ subtract_mean=False,
777
+ use_energy=use_energy,
778
+ use_log_fbank=True,
779
+ use_power=True,
780
+ vtln_high=vtln_high,
781
+ vtln_low=vtln_low,
782
+ vtln_warp=vtln_warp,
783
+ window_type=window_type,
784
+ )
785
+
786
+ if use_energy:
787
+ # size (m)
788
+ signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
789
+ # offset is 0 if htk_compat==True else 1
790
+ mel_offset = int(not htk_compat)
791
+ feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]
792
+
793
+ # size (num_mel_bins, num_ceps)
794
+ dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)
795
+
796
+ # size (m, num_ceps)
797
+ feature = feature.matmul(dct_matrix)
798
+
799
+ if cepstral_lifter != 0.0:
800
+ # size (1, num_ceps)
801
+ lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
802
+ feature *= lifter_coeffs.to(device=device, dtype=dtype)
803
+
804
+ # if use_energy then replace the last column for htk_compat == true else first column
805
+ if use_energy:
806
+ feature[:, 0] = signal_log_energy
807
+
808
+ if htk_compat:
809
+ energy = feature[:, 0].unsqueeze(1) # size (m, 1)
810
+ feature = feature[:, 1:] # size (m, num_ceps - 1)
811
+ if not use_energy:
812
+ # scale on C0 (actually removing a scale we previously added that's
813
+ # part of one common definition of the cosine transform.)
814
+ energy *= math.sqrt(2)
815
+
816
+ feature = torch.cat((feature, energy), dim=1)
817
+
818
+ feature = _subtract_column_mean(feature, subtract_mean)
819
+ return feature
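For completeness, a usage sketch of the mfcc wrapper, analogous to the fbank example above (all values are illustrative):

import torch

waveform = torch.randn(1, 16000)  # fake mono 16 kHz audio
mfccs = mfcc(waveform, num_ceps=13, num_mel_bins=23, sample_frequency=16000.0)
print(mfccs.shape)  # (m, 13)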
eres2net/pooling_layers.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ """ This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class TAP(nn.Module):
11
+ """
12
+ Temporal average pooling, only first-order mean is considered
13
+ """
14
+ def __init__(self, **kwargs):
15
+ super(TAP, self).__init__()
16
+
17
+ def forward(self, x):
18
+ pooling_mean = x.mean(dim=-1)
19
+ # To be compatible with 2D input
20
+ pooling_mean = pooling_mean.flatten(start_dim=1)
21
+ return pooling_mean
22
+
23
+
24
+ class TSDP(nn.Module):
25
+ """
26
+ Temporal standard deviation pooling, only second-order std is considered
27
+ """
28
+ def __init__(self, **kwargs):
29
+ super(TSDP, self).__init__()
30
+
31
+ def forward(self, x):
32
+ # The last dimension is the temporal axis
33
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
34
+ pooling_std = pooling_std.flatten(start_dim=1)
35
+ return pooling_std
36
+
37
+
38
+ class TSTP(nn.Module):
39
+ """
40
+ Temporal statistics pooling, concatenate mean and std, which is used in
41
+ x-vector
42
+ Comment: simple concatenation can not make full use of both statistics
43
+ """
44
+ def __init__(self, **kwargs):
45
+ super(TSTP, self).__init__()
46
+
47
+ def forward(self, x):
48
+ # The last dimension is the temporal axis
49
+ pooling_mean = x.mean(dim=-1)
50
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
51
+ pooling_mean = pooling_mean.flatten(start_dim=1)
52
+ pooling_std = pooling_std.flatten(start_dim=1)
53
+
54
+ stats = torch.cat((pooling_mean, pooling_std), 1)
55
+ return stats
56
+
57
+
58
+ class ASTP(nn.Module):
59
+ """ Attentive statistics pooling: Channel- and context-dependent
60
+ statistics pooling, first used in ECAPA_TDNN.
61
+ """
62
+ def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
63
+ super(ASTP, self).__init__()
64
+ self.global_context_att = global_context_att
65
+
66
+ # Use Conv1d with stride == 1 rather than Linear, then we don't
67
+ # need to transpose inputs.
68
+ if global_context_att:
69
+ self.linear1 = nn.Conv1d(
70
+ in_dim * 3, bottleneck_dim,
71
+ kernel_size=1) # equals W and b in the paper
72
+ else:
73
+ self.linear1 = nn.Conv1d(
74
+ in_dim, bottleneck_dim,
75
+ kernel_size=1) # equals W and b in the paper
76
+ self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
77
+ kernel_size=1) # equals V and k in the paper
78
+
79
+ def forward(self, x):
80
+ """
81
+ x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
82
+ or a 4-dimensional tensor in resnet architecture (B,C,F,T)
83
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
84
+ """
85
+ if len(x.shape) == 4:
86
+ x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
87
+ assert len(x.shape) == 3
88
+
89
+ if self.global_context_att:
90
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
91
+ context_std = torch.sqrt(
92
+ torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
93
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
94
+ else:
95
+ x_in = x
96
+
97
+ # DON'T use ReLU here! ReLU may be hard to converge.
98
+ alpha = torch.tanh(
99
+ self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
100
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
101
+ mean = torch.sum(alpha * x, dim=2)
102
+ var = torch.sum(alpha * (x**2), dim=2) - mean**2
103
+ std = torch.sqrt(var.clamp(min=1e-10))
104
+ return torch.cat([mean, std], dim=1)
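A minimal shape check for the pooling layers above (TAP and TSDP behave analogously); both statistics poolings return mean and std concatenated along the feature axis:

import torch

x = torch.randn(4, 256, 200)       # frame-level features, (B, F, T)
print(TSTP()(x).shape)             # torch.Size([4, 512])
print(ASTP(in_dim=256)(x).shape)   # torch.Size([4, 512])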
inference_webui.py CHANGED
@@ -8,7 +8,6 @@
8
  '''
9
  import logging
10
  import traceback
11
-
12
  logging.getLogger("markdown_it").setLevel(logging.ERROR)
13
  logging.getLogger("urllib3").setLevel(logging.ERROR)
14
  logging.getLogger("httpcore").setLevel(logging.ERROR)
@@ -17,10 +16,13 @@ logging.getLogger("asyncio").setLevel(logging.ERROR)
17
  logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
18
  logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
19
  logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
 
 
 
20
  import gradio.analytics as analytics
21
  analytics.version_check = lambda:None
22
  analytics.get_local_ip_address= lambda :"127.0.0.1"## stub this out, otherwise local machines that cannot reach Amazon's get_local_ip server will hang
23
- import nltk
24
  nltk.download('averaged_perceptron_tagger_eng')
25
  import LangSegment, os, re, sys, json
26
  import pdb
@@ -190,7 +192,7 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
190
 
191
 
192
 
193
- change_sovits_weights("pretrained_models/gsv-v2final-pretrained/s2G2333k.pth")
194
 
195
 
196
  def change_gpt_weights(gpt_path):
@@ -209,27 +211,53 @@ def change_gpt_weights(gpt_path):
209
  print("Number of parameter: %.2fM" % (total / 1e6))
210
 
211
 
212
- change_gpt_weights("pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt")
 
 
 
 
 
213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
- def get_spepc(hps, filename):
216
- audio = load_audio(filename, int(hps.data.sampling_rate))
217
- audio = torch.FloatTensor(audio)
218
- maxx=audio.abs().max()
219
- if(maxx>1):audio/=min(2,maxx)
220
- audio_norm = audio
221
- audio_norm = audio_norm.unsqueeze(0)
222
  spec = spectrogram_torch(
223
- audio_norm,
224
  hps.data.filter_length,
225
  hps.data.sampling_rate,
226
  hps.data.hop_length,
227
  hps.data.win_length,
228
  center=False,
229
  )
230
- return spec
 
 
 
 
231
 
232
  def clean_text_inf(text, language, version):
 
233
  phones, word2ph, norm_text = clean_text(text, language, version)
234
  phones = cleaned_text_to_sequence(phones, version)
235
  return phones, word2ph, norm_text
@@ -257,29 +285,24 @@ def get_first(text):
257
  return text
258
 
259
  from text import chinese
260
- def get_phones_and_bert(text,language,version):
 
261
  if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
262
- language = language.replace("all_","")
263
- if language == "en":
264
- LangSegment.setfilters(["en"])
265
- formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
266
- else:
267
- # Since Chinese/Japanese/Korean Han characters cannot be told apart, follow the user's input language
268
- formattext = text
269
  while " " in formattext:
270
  formattext = formattext.replace(" ", " ")
271
- if language == "zh":
272
- if re.search(r'[A-Za-z]', formattext):
273
- formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
274
  formattext = chinese.mix_text_normalize(formattext)
275
- return get_phones_and_bert(formattext,"zh",version)
276
  else:
277
  phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
278
  bert = get_bert_feature(norm_text, word2ph).to(device)
279
- elif language == "yue" and re.search(r'[A-Za-z]', formattext):
280
- formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
281
- formattext = chinese.mix_text_normalize(formattext)
282
- return get_phones_and_bert(formattext,"yue",version)
283
  else:
284
  phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
285
  bert = torch.zeros(
@@ -287,21 +310,20 @@ def get_phones_and_bert(text,language,version):
287
  dtype=torch.float16 if is_half == True else torch.float32,
288
  ).to(device)
289
  elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
290
- textlist=[]
291
- langlist=[]
292
- LangSegment.setfilters(["zh","ja","en","ko"])
293
  if language == "auto":
294
- for tmp in LangSegment.getTexts(text):
295
  langlist.append(tmp["lang"])
296
  textlist.append(tmp["text"])
297
  elif language == "auto_yue":
298
- for tmp in LangSegment.getTexts(text):
299
  if tmp["lang"] == "zh":
300
  tmp["lang"] = "yue"
301
  langlist.append(tmp["lang"])
302
  textlist.append(tmp["text"])
303
  else:
304
- for tmp in LangSegment.getTexts(text):
305
  if tmp["lang"] == "en":
306
  langlist.append(tmp["lang"])
307
  else:
@@ -322,9 +344,12 @@ def get_phones_and_bert(text,language,version):
322
  bert_list.append(bert)
323
  bert = torch.cat(bert_list, dim=1)
324
  phones = sum(phones_list, [])
325
- norm_text = ''.join(norm_text_list)
 
 
 
326
 
327
- return phones,bert.to(dtype),norm_text
328
 
329
 
330
  def merge_short_text_in_array(texts, threshold):
@@ -461,15 +486,22 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
461
  cache[i_text]=pred_semantic
462
  t3 = ttime()
463
  refers=[]
 
464
  if(inp_refs):
465
  for path in inp_refs:
466
  try:
467
- refer = get_spepc(hps, path.name).to(dtype).to(device)
468
  refers.append(refer)
 
469
  except:
470
  traceback.print_exc()
471
- if(len(refers)==0):refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
472
- audio = (vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers,speed=speed).detach().cpu().numpy()[0, 0])
 
 
 
 
 
473
  max_audio=np.abs(audio).max()# simple guard against 16-bit clipping
474
  if max_audio>1:audio/=max_audio
475
  audio_opt.append(audio)
@@ -674,5 +706,5 @@ if __name__ == '__main__':
674
  inbrowser=True,
675
  # share=True,
676
  # server_port=infer_ttswebui,
677
- quiet=True,
678
  )
 
8
  '''
9
  import logging
10
  import traceback
 
11
  logging.getLogger("markdown_it").setLevel(logging.ERROR)
12
  logging.getLogger("urllib3").setLevel(logging.ERROR)
13
  logging.getLogger("httpcore").setLevel(logging.ERROR)
 
16
  logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
17
  logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
18
  logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
19
+ logging.getLogger("python_multipart.multipart").setLevel(logging.ERROR)
20
+ logging.getLogger("split_lang.split.splitter").setLevel(logging.ERROR)
21
+ from text.LangSegmenter import LangSegmenter
22
  import gradio.analytics as analytics
23
  analytics.version_check = lambda:None
24
  analytics.get_local_ip_address= lambda :"127.0.0.1"## stub this out, otherwise local machines that cannot reach Amazon's get_local_ip server will hang
25
+ import nltk,torchaudio
26
  nltk.download('averaged_perceptron_tagger_eng')
27
  import LangSegment, os, re, sys, json
28
  import pdb
 
192
 
193
 
194
 
195
+ change_sovits_weights("pretrained_models/v2Pro/s2Gv2ProPlus.pth")
196
 
197
 
198
  def change_gpt_weights(gpt_path):
 
211
  print("Number of parameter: %.2fM" % (total / 1e6))
212
 
213
 
214
+ change_gpt_weights("pretrained_models/s1v3.ckpt")
215
+ from sv import SV
216
+ sv_cn_model = SV(device, is_half)
217
+
218
+ resample_transform_dict = {}
219
+
220
 
221
+ def resample(audio_tensor, sr0, sr1, device):
222
+ global resample_transform_dict
223
+ key = "%s-%s-%s" % (sr0, sr1, str(device))
224
+ if key not in resample_transform_dict:
225
+ resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
226
+ return resample_transform_dict[key](audio_tensor)
227
+
228
+
229
+ def get_spepc(hps, filename, dtype, device, is_v2pro=False):
230
+ sr1 = int(hps.data.sampling_rate)
231
+ audio, sr0 = torchaudio.load(filename)
232
+ if sr0 != sr1:
233
+ audio = audio.to(device)
234
+ if audio.shape[0] == 2:
235
+ audio = audio.mean(0).unsqueeze(0)
236
+ audio = resample(audio, sr0, sr1, device)
237
+ else:
238
+ audio = audio.to(device)
239
+ if audio.shape[0] == 2:
240
+ audio = audio.mean(0).unsqueeze(0)
241
 
242
+ maxx = audio.abs().max()
243
+ if maxx > 1:
244
+ audio /= min(2, maxx)
 
 
 
 
245
  spec = spectrogram_torch(
246
+ audio,
247
  hps.data.filter_length,
248
  hps.data.sampling_rate,
249
  hps.data.hop_length,
250
  hps.data.win_length,
251
  center=False,
252
  )
253
+ spec = spec.to(dtype)
254
+ if is_v2pro == True:
255
+ audio = resample(audio, sr1, 16000, device).to(dtype)
256
+ return spec, audio
257
+
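A hedged sketch of how the new get_spepc/resample pair is consumed below; hps, dtype, device and sv_cn_model come from the surrounding WebUI code, and "ref.wav" is a placeholder path:

# is_v2pro=True additionally returns the reference audio resampled to 16 kHz
spec, audio16k = get_spepc(hps, "ref.wav", dtype, device, is_v2pro=True)
sv_vec = sv_cn_model.compute_embedding3(audio16k)
# spec goes into refers, sv_vec into sv_emb, both passed to vq_model.decode(..., sv_emb=sv_emb)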
258
 
259
  def clean_text_inf(text, language, version):
260
+ language = language.replace("all_", "")
261
  phones, word2ph, norm_text = clean_text(text, language, version)
262
  phones = cleaned_text_to_sequence(phones, version)
263
  return phones, word2ph, norm_text
 
285
  return text
286
 
287
  from text import chinese
288
+
289
+ def get_phones_and_bert(text, language, version, final=False):
290
  if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
291
+ formattext = text
 
 
 
 
 
 
292
  while " " in formattext:
293
  formattext = formattext.replace(" ", " ")
294
+ if language == "all_zh":
295
+ if re.search(r"[A-Za-z]", formattext):
296
+ formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
297
  formattext = chinese.mix_text_normalize(formattext)
298
+ return get_phones_and_bert(formattext, "zh", version)
299
  else:
300
  phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
301
  bert = get_bert_feature(norm_text, word2ph).to(device)
302
+ elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
303
+ formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
304
+ formattext = chinese.mix_text_normalize(formattext)
305
+ return get_phones_and_bert(formattext, "yue", version)
306
  else:
307
  phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
308
  bert = torch.zeros(
 
310
  dtype=torch.float16 if is_half == True else torch.float32,
311
  ).to(device)
312
  elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
313
+ textlist = []
314
+ langlist = []
 
315
  if language == "auto":
316
+ for tmp in LangSegmenter.getTexts(text):
317
  langlist.append(tmp["lang"])
318
  textlist.append(tmp["text"])
319
  elif language == "auto_yue":
320
+ for tmp in LangSegmenter.getTexts(text):
321
  if tmp["lang"] == "zh":
322
  tmp["lang"] = "yue"
323
  langlist.append(tmp["lang"])
324
  textlist.append(tmp["text"])
325
  else:
326
+ for tmp in LangSegmenter.getTexts(text):
327
  if tmp["lang"] == "en":
328
  langlist.append(tmp["lang"])
329
  else:
 
344
  bert_list.append(bert)
345
  bert = torch.cat(bert_list, dim=1)
346
  phones = sum(phones_list, [])
347
+ norm_text = "".join(norm_text_list)
348
+
349
+ if not final and len(phones) < 6:
350
+ return get_phones_and_bert("." + text, language, version, final=True)
351
 
352
+ return phones, bert.to(dtype), norm_text
353
 
354
 
355
  def merge_short_text_in_array(texts, threshold):
 
486
  cache[i_text]=pred_semantic
487
  t3 = ttime()
488
  refers=[]
489
+ sv_emb = []
490
  if(inp_refs):
491
  for path in inp_refs:
492
  try:
493
+ refer, audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro=True)
494
  refers.append(refer)
495
+ sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor))
496
  except:
497
  traceback.print_exc()
498
+ if len(refers) == 0:
499
+ refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro=True)
500
+ refers = [refers]
501
+ sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
502
+ audio = vq_model.decode(
503
+ pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed, sv_emb=sv_emb
504
+ ).detach().cpu().numpy()[0][0]
505
  max_audio=np.abs(audio).max()# simple guard against 16-bit clipping
506
  if max_audio>1:audio/=max_audio
507
  audio_opt.append(audio)
 
706
  inbrowser=True,
707
  # share=True,
708
  # server_port=infer_ttswebui,
709
+ # quiet=True,
710
  )
module/models.py CHANGED
@@ -912,6 +912,9 @@ class SynthesizerTrn(nn.Module):
912
 
913
  self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
914
  self.freeze_quantizer = freeze_quantizer
 
 
 
915
 
916
  def forward(self, ssl, y, y_lengths, text, text_lengths):
917
  y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
@@ -921,6 +924,10 @@ class SynthesizerTrn(nn.Module):
921
  ge = self.ref_enc(y * y_mask, y_mask)
922
  else:
923
  ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
 
 
 
 
924
  with autocast(enabled=False):
925
  maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
926
  with maybe_no_grad:
@@ -938,7 +945,7 @@ class SynthesizerTrn(nn.Module):
938
  )
939
 
940
  x, m_p, logs_p, y_mask = self.enc_p(
941
- quantized, y_lengths, text, text_lengths, ge
942
  )
943
  z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
944
  z_p = self.flow(z, y_mask, g=ge)
@@ -984,8 +991,8 @@ class SynthesizerTrn(nn.Module):
984
  return o, y_mask, (z, z_p, m_p, logs_p)
985
 
986
  @torch.no_grad()
987
- def decode(self, codes, text, refer, noise_scale=0.5,speed=1):
988
- def get_ge(refer):
989
  ge = None
990
  if refer is not None:
991
  refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
@@ -996,15 +1003,18 @@ class SynthesizerTrn(nn.Module):
996
  ge = self.ref_enc(refer * refer_mask, refer_mask)
997
  else:
998
  ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
 
 
 
999
  return ge
1000
  if(type(refer)==list):
1001
  ges=[]
1002
- for _refer in refer:
1003
- ge=get_ge(_refer)
1004
  ges.append(ge)
1005
  ge=torch.stack(ges,0).mean(0)
1006
  else:
1007
- ge=get_ge(refer)
1008
 
1009
  y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
1010
  text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
@@ -1015,7 +1025,7 @@ class SynthesizerTrn(nn.Module):
1015
  quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
1016
  )
1017
  x, m_p, logs_p, y_mask = self.enc_p(
1018
- quantized, y_lengths, text, text_lengths, ge,speed
1019
  )
1020
  z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1021
 
 
912
 
913
  self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
914
  self.freeze_quantizer = freeze_quantizer
915
+ self.sv_emb = nn.Linear(20480, gin_channels)
916
+ self.ge_to512 = nn.Linear(gin_channels, 512)
917
+ self.prelu = nn.PReLU(num_parameters=gin_channels)
918
 
919
  def forward(self, ssl, y, y_lengths, text, text_lengths):
920
  y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
 
924
  ge = self.ref_enc(y * y_mask, y_mask)
925
  else:
926
  ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
927
+ sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
928
+ ge += sv_emb.unsqueeze(-1)
929
+ ge = self.prelu(ge)
930
+ ge512 = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1)
931
  with autocast(enabled=False):
932
  maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
933
  with maybe_no_grad:
 
945
  )
946
 
947
  x, m_p, logs_p, y_mask = self.enc_p(
948
+ quantized, y_lengths, text, text_lengths, ge512
949
  )
950
  z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
951
  z_p = self.flow(z, y_mask, g=ge)
 
991
  return o, y_mask, (z, z_p, m_p, logs_p)
992
 
993
  @torch.no_grad()
994
+ def decode(self, codes, text, refer, noise_scale=0.5,speed=1, sv_emb=None):
995
+ def get_ge(refer, sv_emb):
996
  ge = None
997
  if refer is not None:
998
  refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
 
1003
  ge = self.ref_enc(refer * refer_mask, refer_mask)
1004
  else:
1005
  ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
1006
+ sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
1007
+ ge += sv_emb.unsqueeze(-1)
1008
+ ge = self.prelu(ge)
1009
  return ge
1010
  if(type(refer)==list):
1011
  ges=[]
1012
+ for idx,_refer in enumerate(refer):
1013
+ ge=get_ge(_refer,sv_emb[idx])
1014
  ges.append(ge)
1015
  ge=torch.stack(ges,0).mean(0)
1016
  else:
1017
+ ge = get_ge(refer, sv_emb)
1018
 
1019
  y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
1020
  text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
 
1025
  quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
1026
  )
1027
  x, m_p, logs_p, y_mask = self.enc_p(
1028
+ quantized, y_lengths, text, text_lengths, self.ge_to512(ge.transpose(2,1)).transpose(2,1),speed
1029
  )
1030
  z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1031
 
requirements.txt CHANGED
@@ -5,7 +5,8 @@ librosa==0.9.2
5
  numba==0.56.4
6
  torchaudio
7
  pytorch-lightning>=2.4
8
- gradio<5
 
9
  ffmpeg-python==0.2.0
10
  onnxruntime-gpu
11
  tqdm==4.66.4
@@ -14,7 +15,7 @@ pypinyin==0.50.0
14
  pyopenjtalk==0.4.1
15
  g2p_en==2.1.0
16
  sentencepiece==0.1.99
17
- transformers==4.35.0
18
  chardet==3.0.4
19
  PyYAML==6.0.1
20
  psutil==5.9.7
@@ -30,5 +31,9 @@ ko_pron==1.3
30
  opencc==1.1.0
31
  python_mecab_ko==1.3.7
32
  torch==2.4
33
- pydantic<=2.10.6
34
- torchmetrics<=1.5
 
 
 
 
 
5
  numba==0.56.4
6
  torchaudio
7
  pytorch-lightning>=2.4
8
+ gradio==4.44.1
9
+ gradio_client==1.3.0
10
  ffmpeg-python==0.2.0
11
  onnxruntime-gpu
12
  tqdm==4.66.4
 
15
  pyopenjtalk==0.4.1
16
  g2p_en==2.1.0
17
  sentencepiece==0.1.99
18
+ transformers==4.43.0
19
  chardet==3.0.4
20
  PyYAML==6.0.1
21
  psutil==5.9.7
 
31
  opencc==1.1.0
32
  python_mecab_ko==1.3.7
33
  torch==2.4
34
+ pydantic==2.8.2
35
+ torchmetrics<=1.5
36
+ nltk==3.8.1
37
+ fast_langdetect==0.3.1
38
+ split_lang==2.1.0
39
+ ToJyutping==3.2.0
sv.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys,os,torch
2
+ sys.path.append(f"{os.getcwd()}/eres2net")
3
+ sv_path = "pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt"
4
+ from ERes2NetV2 import ERes2NetV2
5
+ import kaldi as Kaldi
6
+ class SV:
7
+ def __init__(self,device,is_half):
8
+ pretrained_state = torch.load(sv_path, map_location='cpu', weights_only=False)
9
+ embedding_model = ERes2NetV2(baseWidth=24,scale=4,expansion=4)
10
+ embedding_model.load_state_dict(pretrained_state)
11
+ embedding_model.eval()
12
+ self.embedding_model=embedding_model
13
+ if is_half == False:
14
+ self.embedding_model=self.embedding_model.to(device)
15
+ else:
16
+ self.embedding_model=self.embedding_model.half().to(device)
17
+ self.is_half=is_half
18
+
19
+ def compute_embedding3(self,wav):
20
+ with torch.no_grad():
21
+ if self.is_half==True:wav=wav.half()
22
+ feat = torch.stack([Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav])
23
+ sv_emb = self.embedding_model.forward3(feat)
24
+ return sv_emb
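A usage sketch for the new SV wrapper, assuming the ERes2NetV2 checkpoint exists under pretrained_models/sv/ and a CPU run without half precision:

import torch

sv_model = SV(device="cpu", is_half=False)
wav16k = torch.randn(1, 16000)             # batch of mono 16 kHz reference audio
emb = sv_model.compute_embedding3(wav16k)  # speaker embedding consumed as sv_emb in module/models.py
print(emb.shape)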
utils.py CHANGED
@@ -1,4 +1,4 @@
1
- import os
2
  import glob
3
  import sys
4
  import argparse
 
1
+ import os##
2
  import glob
3
  import sys
4
  import argparse