Spaces: Running on L40S
root committed
Commit · 98a0e3b
1 Parent(s): 167c6ec
add lowmem mode
Browse files
- codeclm/models/builders.py +18 -2
- codeclm/models/codeclm.py +10 -7
- codeclm/tokenizer/Flow1dVAE/generate_1rvq.py +8 -1
- codeclm/tokenizer/Flow1dVAE/generate_2rvq.py +0 -1
- codeclm/tokenizer/Flow1dVAE/generate_4rvq.py +0 -1
- codeclm/tokenizer/Flow1dVAE/generate_septoken.py +11 -4
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/musicfm_25hz.py +2 -45
- codeclm/tokenizer/audio_tokenizer.py +30 -7
- codeclm/trainer/codec_song_pl.py +1 -4
- codeclm/utils/offload_profiler.py +505 -0
- download.py +1 -1
- generate.py +440 -42
- generate.sh +63 -2
- generate_lowmem.py +0 -241
- generate_lowmem.sh +0 -11
- tools/gradio/app.py +1 -2
- tools/gradio/levo_inference.py +1 -1
- tools/gradio/levo_inference_lowmem.py +65 -23
codeclm/models/builders.py
CHANGED
@@ -29,13 +29,29 @@ def get_audio_tokenizer_model(checkpoint_path: str, cfg: omegaconf.DictConfig):
         return None
     if checkpoint_path.startswith('//pretrained/'):
         name = checkpoint_path.split('/', 3)[-1]
-        return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, '
+        return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cuda', mode=cfg.mode)
     elif checkpoint_path == "":
         return None
     else:
         name = checkpoint_path
-        return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, '
+        return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cuda', mode=cfg.mode)
 
+
+def get_audio_tokenizer_model_cpu(checkpoint_path: str, cfg: omegaconf.DictConfig):
+    from codeclm.tokenizer.audio_tokenizer import AudioTokenizer
+    """Instantiate a compression model."""
+    if checkpoint_path is None:
+        return None
+    if checkpoint_path.startswith('//pretrained/'):
+        name = checkpoint_path.split('/', 3)[-1]
+        return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode, tango_device='cpu')
+    elif checkpoint_path == "":
+        return None
+    else:
+        name = checkpoint_path
+        return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode, tango_device='cpu')
+
+
 def get_lm_model(cfg: omegaconf.DictConfig): #-> LMModel:
     """Instantiate a LM."""
     lm_kwargs = dict_from_config(getattr(cfg, 'lm'))
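The new `get_audio_tokenizer_model_cpu` builder keeps the Flow1dVAE tokenizer (and its internal Tango model) on the CPU until it is actually needed. A minimal sketch, not part of the commit, of how a caller might pick between the two builders; `cfg` is assumed to be the OmegaConf config shipped with the checkpoint and `low_mem` mirrors the new `--low_mem` flag:

```python
# Sketch only; cfg.audio_tokenizer_checkpoint comes from config.yaml.
from codeclm.models import builders

def build_tokenizer(cfg, low_mem: bool):
    if low_mem:
        # weights stay on CPU; move them to the GPU only around decode calls
        return builders.get_audio_tokenizer_model_cpu(cfg.audio_tokenizer_checkpoint, cfg)
    return builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
```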
codeclm/models/codeclm.py
CHANGED
@@ -271,21 +271,24 @@ class CodecLM:
         return gen_tokens
 
     @torch.no_grad()
-    def generate_audio(self, gen_tokens: torch.Tensor, prompt=None, vocal_prompt=None, bgm_prompt=None, chunked=False, gen_type=
+    def generate_audio(self, gen_tokens: torch.Tensor, prompt=None, vocal_prompt=None, bgm_prompt=None, chunked=False, chunk_size=128, gen_type='mixed'):
         """Generate Audio from tokens"""
         assert gen_tokens.dim() == 3
         if self.seperate_tokenizer is not None:
             gen_tokens_song = gen_tokens[:, [0], :]
             gen_tokens_vocal = gen_tokens[:, [1], :]
             gen_tokens_bgm = gen_tokens[:, [2], :]
-            if gen_type ==
+            if gen_type == 'bgm':
                 gen_tokens_vocal = torch.full_like(gen_tokens_vocal, 3142)
-                vocal_prompt
-
+                if vocal_prompt is not None:
+                    vocal_prompt = torch.zeros_like(vocal_prompt)
+            elif gen_type == 'vocal':
                 gen_tokens_bgm = torch.full_like(gen_tokens_bgm, 9670)
-                bgm_prompt
-
-
+                if bgm_prompt is not None:
+                    bgm_prompt = torch.zeros_like(bgm_prompt)
+            else:
+                assert gen_type == 'mixed', f"gen_type {gen_type} not supported"
+            gen_audio_seperate = self.seperate_tokenizer.decode([gen_tokens_vocal, gen_tokens_bgm], vocal_prompt, bgm_prompt, chunked=chunked, chunk_size=chunk_size)
             return gen_audio_seperate
         else:
             gen_audio = self.audiotokenizer.decode(gen_tokens, prompt)
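With the extended signature, callers can pick the decode target through `gen_type` and bound decoder memory through `chunk_size`. A hedged usage sketch (not from the commit), assuming `model` is a `CodecLM` instance and `tokens` is a `[B, 3, T]` tensor of song/vocal/bgm codes:

```python
# Sketch only; tensor and variable names are placeholders.
import torch

with torch.no_grad():
    wav_mixed = model.generate_audio(tokens, gen_type='mixed', chunked=True, chunk_size=64)
    wav_vocal = model.generate_audio(tokens, gen_type='vocal', chunked=True, chunk_size=64)
```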
codeclm/tokenizer/Flow1dVAE/generate_1rvq.py
CHANGED
@@ -46,7 +46,6 @@ class Tango:
 
         self.model.eval()
         self.model.init_device_dtype(torch.device(device), torch.float32)
-        print("scaling factor: ", self.model.normfeat.std)
 
         # self.scheduler = DDIMScheduler.from_pretrained( \
        #     scheduler_name, subfolder="scheduler")
@@ -281,3 +280,11 @@ class Tango:
             else:
                 output = torch.cat([output, cur_output], -1)
         return output
+
+    def to(self, device=None, dtype=None, non_blocking=False):
+        if device is not None:
+            self.device = device
+            self.model.device = device
+        self.vae = self.vae.to(device, dtype, non_blocking)
+        self.model = self.model.to(device, dtype, non_blocking)
+        return self
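The new `Tango.to()` lets the 1-RVQ tokenizer be parked on the CPU while the language model occupies the GPU, which is the pattern the low-memory path relies on. A hypothetical helper illustrating that pattern (the helper and its arguments are placeholders, not part of the commit):

```python
import torch

def run_lm_then_decode(tango, lm_step, decode_step):
    # park the tokenizer on CPU while the LM runs, then bring it back to decode
    tango.to('cpu')
    torch.cuda.empty_cache()
    tokens = lm_step()
    tango.to('cuda:0')
    return decode_step(tokens)
```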
codeclm/tokenizer/Flow1dVAE/generate_2rvq.py
CHANGED
@@ -51,7 +51,6 @@ class Tango:
 
         self.model.eval()
         self.model.init_device_dtype(torch.device(device), torch.float32)
-        print("scaling factor: ", self.model.normfeat.std)
 
         # self.scheduler = DDIMScheduler.from_pretrained( \
         #     scheduler_name, subfolder="scheduler")
codeclm/tokenizer/Flow1dVAE/generate_4rvq.py
CHANGED
@@ -50,7 +50,6 @@ class Tango:
 
         self.model.eval()
         self.model.init_device_dtype(torch.device(device), torch.float32)
-        print("scaling factor: ", self.model.normfeat.std)
 
         # self.scheduler = DDIMScheduler.from_pretrained( \
         #     scheduler_name, subfolder="scheduler")
codeclm/tokenizer/Flow1dVAE/generate_septoken.py
CHANGED
@@ -102,7 +102,6 @@ class Tango:
 
         self.model.eval()
         self.model.init_device_dtype(torch.device(device), torch.float32)
-        print("scaling factor: ", self.model.normfeat.std)
 
         # self.scheduler = DDIMScheduler.from_pretrained( \
         #     scheduler_name, subfolder="scheduler")
@@ -173,7 +172,7 @@ class Tango:
         return codes_vocal, codes_bgm
 
     @torch.no_grad()
-    def code2sound(self, codes, prompt_vocal=None, prompt_bgm=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False, chunked=False):
+    def code2sound(self, codes, prompt_vocal=None, prompt_bgm=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False, chunked=False, chunk_size=128):
         codes_vocal,codes_bgm = codes
         codes_vocal = codes_vocal.to(self.device)
         codes_bgm = codes_bgm.to(self.device)
@@ -188,7 +187,7 @@ class Tango:
         first_latent_codes_length = 0
 
 
-        if
+        if(isinstance(prompt_vocal, torch.Tensor) and isinstance(prompt_bgm, torch.Tensor)):
             # prepare prompt
             prompt_vocal = prompt_vocal.to(self.device)
             prompt_bgm = prompt_bgm.to(self.device)
@@ -273,7 +272,7 @@ class Tango:
         output = None
         for i in range(len(latent_list)):
             latent = latent_list[i]
-            cur_output = self.vae.decode_audio(latent, chunked=chunked)[0].detach().cpu()
+            cur_output = self.vae.decode_audio(latent, chunked=chunked, chunk_size=chunk_size)[0].detach().cpu()
 
             if output is None:
                 output = cur_output
@@ -301,3 +300,11 @@ class Tango:
         codes=[codes_vocal, codes_bgm]
         wave = self.code2sound(codes, prompt_vocal,prompt_bgm, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
         return wave
+
+    def to(self, device=None, dtype=None, non_blocking=False):
+        if device is not None:
+            self.device = device
+            self.model.device = device
+        self.vae = self.vae.to(device, dtype, non_blocking)
+        self.model = self.model.to(device, dtype, non_blocking)
+        return self
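`code2sound` now forwards `chunk_size` to the VAE's `decode_audio`, so the separated decode can trade speed for a smaller peak VRAM footprint. A hedged example call (not from the commit), assuming `tango` is the septoken `Tango` instance and `codes_vocal`/`codes_bgm` have already been extracted:

```python
# Sketch only; a smaller chunk_size lowers peak GPU memory at some cost in speed.
wav = tango.code2sound([codes_vocal, codes_bgm],
                       prompt_vocal=None, prompt_bgm=None,
                       guidance_scale=1.5, num_steps=50,
                       chunked=True, chunk_size=64)
```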
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/musicfm_25hz.py
CHANGED
@@ -78,7 +78,6 @@ class MusicFM25Hz(nn.Module):
             with open(stat_path, "r") as f:
                 self.stat = json.load(f)
         else:
-            print("No stats file found at `{}`, use default from msd.".format(stat_path))
             self.stat = {"spec_256_cnt": 14394344256, "spec_256_mean": -23.34296658431829, "spec_256_std": 26.189295587132637, "spec_512_cnt": 28677104448, "spec_512_mean": -21.31267396860235, "spec_512_std": 26.52644536245769, "spec_1024_cnt": 57242624832, "spec_1024_mean": -18.852271129208273, "spec_1024_std": 26.443154583585663, "spec_2048_cnt": 114373665600, "spec_2048_mean": -15.638743433896792, "spec_2048_std": 26.115825961611545, "spec_4096_cnt": 228635747136, "spec_4096_mean": -11.715532502794836, "spec_4096_std": 25.763972210234062, "melspec_256_cnt": 14282760192, "melspec_256_mean": -26.962600400166156, "melspec_256_std": 36.13614100912126, "melspec_512_cnt": 14282760192, "melspec_512_mean": -9.108344167718862, "melspec_512_std": 24.71910937988429, "melspec_1024_cnt": 14282760192, "melspec_1024_mean": 0.37302579246531126, "melspec_1024_std": 18.684082325919388, "melspec_2048_cnt": 14282760192, "melspec_2048_mean": 6.768444971712967, "melspec_2048_std": 18.417922652295623, "melspec_4096_cnt": 14282760192, "melspec_4096_mean": 13.617164614990036, "melspec_4096_std": 18.08552130124525, "cqt_cnt": 9373061376, "cqt_mean": 0.46341379757927165, "cqt_std": 0.9543998080910191, "mfcc_256_cnt": 1339008768, "mfcc_256_mean": -11.681755459447485, "mfcc_256_std": 29.183186444668316, "mfcc_512_cnt": 1339008768, "mfcc_512_mean": -2.540581461792183, "mfcc_512_std": 31.93752185832081, "mfcc_1024_cnt": 1339008768, "mfcc_1024_mean": 6.606636263169779, "mfcc_1024_std": 34.151644801729624, "mfcc_2048_cnt": 1339008768, "mfcc_2048_mean": 5.281600844245184, "mfcc_2048_std": 33.12784541220003, "mfcc_4096_cnt": 1339008768, "mfcc_4096_mean": 4.7616569480166095, "mfcc_4096_std": 32.61458906894133, "chromagram_256_cnt": 1339008768, "chromagram_256_mean": 55.15596556703181, "chromagram_256_std": 73.91858278719991, "chromagram_512_cnt": 1339008768, "chromagram_512_mean": 175.73092252759895, "chromagram_512_std": 248.48485148525953, "chromagram_1024_cnt": 1339008768, "chromagram_1024_mean": 589.2947481634608, "chromagram_1024_std": 913.857929063196, "chromagram_2048_cnt": 1339008768, "chromagram_2048_mean": 2062.286388327397, "chromagram_2048_std": 3458.92657915397, "chromagram_4096_cnt": 1339008768, "chromagram_4096_mean": 7673.039107997085, "chromagram_4096_std": 13009.883158267234}
 
         # feature extractor
@@ -90,40 +89,6 @@ class MusicFM25Hz(nn.Module):
         self.use_rvq_target = use_rvq_target
 
         seed = 142
-        if use_rvq_target:
-            try:
-                from .rvq_musicfm import ResidualVectorQuantize
-
-            except:
-                import sys, os
-                sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-                from rvq_musicfm import ResidualVectorQuantize
-
-            self.rvq = ResidualVectorQuantize(
-                input_dim = 128*4,
-                n_codebooks = 8,
-                codebook_size = 1024,
-                codebook_dim = 16,
-                quantizer_dropout = 0.0,
-            )
-            import os
-            if rvq_ckpt_path is not None and os.path.exists(rvq_ckpt_path):
-                state_dict = torch.load(rvq_ckpt_path, map_location="cpu")
-                self.rvq.load_state_dict(state_dict)
-            else:
-                print(f'Checkpoint for rvq `{rvq_ckpt_path}` not found. Using random initialization.')
-
-        else:
-            for feature in self.features:
-                for i in range(num_codebooks):
-                    setattr(
-                        self,
-                        f"quantizer_{feature}", # _{i}
-                        RandomProjectionQuantizer(
-                            n_mels * 4, codebook_dim, codebook_size, seed=seed + i
-                        ),
-                    )
-
         # two residual convolution layers + one projection layer
         self.conv = Conv2dSubsampling(
             1, conv_dim, encoder_dim, strides=[2, 2], n_bands=n_mels
@@ -247,16 +212,8 @@ class MusicFM25Hz(nn.Module):
     @torch.no_grad()
     def tokenize(self, x):
         out = {}
-
-
-            self.rvq.eval()
-            quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = self.rvq(x[key].permute((0, 2, 1)))
-            out[key] = torch.cat([codes[:, idx, :] for idx in range(int(self.codebook_size//1024))], dim=-1)
-        else:
-            layer = getattr(self, "quantizer_%s" % key)
-            out[key] = layer(x[key])
-        return out
-
+        raise NotImplementedError("tokenize is not implemented")
+
     def get_targets(self, x):
         x = self.preprocessing(x, features=self.features) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}}
         x = self.normalize(x)
codeclm/tokenizer/audio_tokenizer.py
CHANGED
@@ -78,7 +78,8 @@ class AudioTokenizer(ABC, nn.Module):
         vae_config: str,
         vae_model: str,
         device: tp.Union[torch.device, str] = 'cpu',
-        mode='extract'
+        mode='extract',
+        tango_device:str='cuda'
     ) -> 'AudioTokenizer':
         """Instantiate a AudioTokenizer model from a given pretrained model.
 
@@ -91,11 +92,11 @@ class AudioTokenizer(ABC, nn.Module):
         if name.split('_')[0] == 'Flow1dVAESeparate':
             model_type = name.split('_', 1)[1]
             logger.info("Getting pretrained compression model from semantic model %s", model_type)
-            model = Flow1dVAESeparate(model_type, vae_config, vae_model)
+            model = Flow1dVAESeparate(model_type, vae_config, vae_model, tango_device=tango_device)
         elif name.split('_')[0] == 'Flow1dVAE1rvq':
             model_type = name.split('_', 1)[1]
             logger.info("Getting pretrained compression model from semantic model %s", model_type)
-            model = Flow1dVAE1rvq(model_type, vae_config, vae_model)
+            model = Flow1dVAE1rvq(model_type, vae_config, vae_model, tango_device=tango_device)
         else:
             raise NotImplementedError("{} is not implemented in models/audio_tokenizer.py".format(
                 name))
@@ -108,12 +109,13 @@ class Flow1dVAE1rvq(AudioTokenizer):
         model_type: str = "model_2_fixed.safetensors",
         vae_config: str = "",
         vae_model: str = "",
+        tango_device: str = "cuda"
     ):
         super().__init__()
 
         from codeclm.tokenizer.Flow1dVAE.generate_1rvq import Tango
         model_path = model_type
-        self.model = Tango(model_path=model_path, vae_config=vae_config, vae_model=vae_model, device=
+        self.model = Tango(model_path=model_path, vae_config=vae_config, vae_model=vae_model, device=tango_device)
         print ("Successfully loaded checkpoint from:", model_path)
 
 
@@ -176,6 +178,15 @@ class Flow1dVAE1rvq(AudioTokenizer):
         assert n <= self.total_codebooks
         self.n_quantizers = n
 
+    def to(self, device=None, dtype=None, non_blocking=False):
+        self = super(Flow1dVAE1rvq, self).to(device, dtype, non_blocking)
+        self.model = self.model.to(device, dtype, non_blocking)
+        return self
+
+    def cuda(self, device=None):
+        if device is None:
+            device = 'cuda:0'
+        return super(Flow1dVAE1rvq, self).cuda(device)
 
 class Flow1dVAESeparate(AudioTokenizer):
     def __init__(
@@ -183,12 +194,13 @@ class Flow1dVAESeparate(AudioTokenizer):
         model_type: str = "model_2.safetensors",
         vae_config: str = "",
         vae_model: str = "",
+        tango_device: str = "cuda"
     ):
         super().__init__()
 
         from codeclm.tokenizer.Flow1dVAE.generate_septoken import Tango
         model_path = model_type
-        self.model = Tango(model_path=model_path, vae_config=vae_config, vae_model=vae_model, device=
+        self.model = Tango(model_path=model_path, vae_config=vae_config, vae_model=vae_model, device=tango_device)
         print ("Successfully loaded checkpoint from:", model_path)
 
 
@@ -208,9 +220,9 @@ class Flow1dVAESeparate(AudioTokenizer):
         return codes_vocal, codes_bgm
 
     @torch.no_grad()
-    def decode(self, codes: torch.Tensor, prompt_vocal = None, prompt_bgm = None, chunked=False):
+    def decode(self, codes: torch.Tensor, prompt_vocal = None, prompt_bgm = None, chunked=False, chunk_size=128):
         wav = self.model.code2sound(codes, prompt_vocal=prompt_vocal, prompt_bgm=prompt_bgm, guidance_scale=1.5,
-                                    num_steps=50, disable_progress=False, chunked=chunked) # [B,N,T] -> [B,T]
+                                    num_steps=50, disable_progress=False, chunked=chunked, chunk_size=chunk_size) # [B,N,T] -> [B,T]
         return wav[None]
 
 
@@ -251,3 +263,14 @@ class Flow1dVAESeparate(AudioTokenizer):
         assert n >= 1
         assert n <= self.total_codebooks
         self.n_quantizers = n
+
+    def to(self, device=None, dtype=None, non_blocking=False):
+        self = super(Flow1dVAESeparate, self).to(device, dtype, non_blocking)
+        self.model = self.model.to(device, dtype, non_blocking)
+        return self
+
+    def cuda(self, device=None):
+        if device is None:
+            device = 'cuda:0'
+        self = super(Flow1dVAESeparate, self).cuda(device)
+        return self
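Together, `tango_device`, `to()` and `cuda()` let the tokenizers be constructed on CPU and moved to the GPU only around decoding. A hedged sketch (not from the commit); the `name`, `vae_config` and `vae_model` values are placeholders for whatever the config provides:

```python
# Sketch only; mirrors how the low-memory path is expected to drive the tokenizer.
tokenizer = AudioTokenizer.get_pretrained(name, vae_config, vae_model,
                                          'cpu', mode='inference', tango_device='cpu')
tokenizer = tokenizer.cuda()                     # defaults to 'cuda:0'
wav = tokenizer.decode(codes, chunked=True, chunk_size=128)
tokenizer = tokenizer.to('cpu')                  # release VRAM again
```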
codeclm/trainer/codec_song_pl.py
CHANGED
@@ -49,9 +49,7 @@ class CodecLM_PL(pl.LightningModule):
         # 3) Load pretrained checkpoint (if any)
         checkpoint = torch.load(ckpt_path, map_location='cpu')
         missing, unexpected = self.load_state_dict(checkpoint, strict=False)
-        print(
-        print(f'-------------Unexpected--------------\n{unexpected}')
-        print("successfully load deepspeed pretrained model {}".format(ckpt_path))
+        print("successfully load pretrained model {}".format(ckpt_path))
         # 4) Build metrics
         self.val_steps = []
         self.train_slide_acc = []
@@ -70,7 +68,6 @@ class CodecLM_PL(pl.LightningModule):
             ) for _ in range(self.audiolm.code_depth)])
 
         self.epoch = 0
-        print("++++++++++++++++ training <song> +++++++++++++++++")
 
         # TODO: move this part to loader
     def generate_mask_and_end_token(self, x, sequence_lengths, end_id=16384):
codeclm/utils/offload_profiler.py
ADDED
@@ -0,0 +1,505 @@
import torch
from torch.func import functional_call
import queue
import threading
from typing import Dict, List, Any
import omegaconf
from pydantic import BaseModel, validator
from typing import Optional
from functools import wraps

def _callable_once(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        method_called_flag = f"_called_once_{func.__name__}"
        if getattr(self, method_called_flag, False):
            raise RuntimeError(f"{func.__name__} can only be called once.")
        setattr(self, method_called_flag, True)
        return func(self, *args, **kwargs)
    return wrapper

class OffloadCleanCacheWrapperParam(BaseModel):
    module: Any
    method_name: str
    diff_mem_gb_thre: float

class OffloadParam(BaseModel):
    offload_module: Any
    cpu_mem_gb: float
    pre_copy_step: Optional[int] = None
    clean_cache_after_forward: Optional[bool] = None
    dtype: Optional[str] = None
    offload_layer_dict: Dict[str, int] = {}
    ignore_layer_list: List[str] = []
    clean_cache_wrapper: Optional[OffloadCleanCacheWrapperParam] = None
    debug: Optional[bool] = None

    @validator('dtype')
    def parse_dtype(cls, value):
        if value is None:
            return None
        dtype_map = {
            'torch.float16': torch.float16,
            'torch.float32': torch.float32,
            'torch.float64': torch.float64,
            'torch.int64': torch.int64,
        }
        if value not in dtype_map:
            raise ValueError(f"Unsupported dtype: {value}")
        return dtype_map[value]

    def init_param_dict(self):
        param_dict = {}
        param_dict['cpu_mem_gb'] = self.cpu_mem_gb
        if self.pre_copy_step is not None:
            param_dict['pre_copy_step'] = self.pre_copy_step
        if self.clean_cache_after_forward is not None:
            param_dict['clean_cache_after_forward'] = self.clean_cache_after_forward
        if self.debug is not None:
            param_dict['debug'] = self.debug

        return param_dict

    def offload_layer_param_dict(self):
        param_dict = {}
        param_dict['module'] = self.offload_module
        param_dict['offload_layer_dict'] = self.offload_layer_dict
        param_dict['ignore_layer_list'] = self.ignore_layer_list
        param_dict['dtype'] = self.dtype

        return param_dict

    def clean_cache_param_dict(self):
        param_dict = {}
        if self.clean_cache_wrapper is not None:
            param_dict['module'] = self.clean_cache_wrapper.module
            param_dict['method_name'] = self.clean_cache_wrapper.method_name
            param_dict['diff_mem_gb_thre'] = self.clean_cache_wrapper.diff_mem_gb_thre

        return param_dict

    @staticmethod
    def recursive_print(model, indent=0):
        for field_name, field_info in model.__fields__.items():
            field_value = getattr(model, field_name)
            print(" " * indent + f"{field_name}:")

            if issubclass(type(field_value), BaseModel):
                print(" " * (indent + 2) + f"--- Nested model: {field_value.__class__.__name__}")
                OffloadParam.recursive_print(field_value, indent + 4)
            else:
                print(" " * (indent + 2) + f"class: {field_value.__class__.__name__}")
                if isinstance(field_value, torch.nn.Module):
                    pass
                else:
                    print(" " * (indent + 2) + f"value: {field_value}")

    def show(self):
        print("-"*20 + "[OffloadParam]" + "-"*20)
        OffloadParam.recursive_print(self)
        print("-"*40)


class OffloadParamParse:
    def __init__(self):
        pass

    @staticmethod
    def _get_model(root_model: torch.nn.Module, model_dir: str):
        assert(model_dir.startswith("self")), f"model_dir {model_dir} must startswith `self`"
        model = root_model
        for layer in model_dir.split('.'):
            if layer == "self":
                continue
            assert(hasattr(model, layer)), f"model not has layer [{layer}]!"
            model = getattr(model, layer)
        return model

    @staticmethod
    def parse_config(root_model: torch.nn.Module, cfg: omegaconf.DictConfig)->OffloadParam:
        assert(hasattr(cfg, "offload_module") and hasattr(cfg, "cpu_mem_gb") and hasattr(cfg, "dtype"))

        offload_module = OffloadParamParse._get_model(root_model, cfg.offload_module)
        cpu_mem_gb = cfg.cpu_mem_gb
        dtype = cfg.dtype

        pre_copy_step = cfg.pre_copy_step \
            if hasattr(cfg, "pre_copy_step") else None

        clean_cache_after_forward = cfg.clean_cache_after_forward \
            if hasattr(cfg, "clean_cache_after_forward") else None

        offload_layer_dict = {k: v for k, v in cfg.offload_layer_dict.items()} \
            if hasattr(cfg, "offload_layer_dict") else {}

        ignore_layer_list = cfg.ignore_layer_list \
            if hasattr(cfg, "ignore_layer_list") else []

        debug = cfg.debug if hasattr(cfg, "debug") else None

        clean_cache_wrapper = None
        if hasattr(cfg, "clean_cache_wrapper"):
            clean_cache_cfg = cfg.clean_cache_wrapper
            cc_module = OffloadParamParse._get_model(root_model, clean_cache_cfg.module)
            cc_method_name = clean_cache_cfg.method_name
            diff_mem_gb_thre = clean_cache_cfg.diff_mem_gb_thre
            clean_cache_wrapper = OffloadCleanCacheWrapperParam(
                module=cc_module,
                method_name=cc_method_name,
                diff_mem_gb_thre=diff_mem_gb_thre)

        return OffloadParam(
            offload_module=offload_module,
            cpu_mem_gb=cpu_mem_gb,
            pre_copy_step=pre_copy_step,
            clean_cache_after_forward=clean_cache_after_forward,
            dtype=dtype,
            offload_layer_dict=offload_layer_dict,
            ignore_layer_list=ignore_layer_list,
            clean_cache_wrapper=clean_cache_wrapper,
            debug=debug
        )


class LayerParamStruct:
    def __init__(self):
        self.count = 0
        self.device_state = None


class OffloadProfiler:
    def __init__(self, device_index=0, cpu_mem_gb=-1, pre_copy_step=1, clean_cache_after_forward=False, debug=False):
        self.clean_cache_after_forward = clean_cache_after_forward
        self.cpu_mem_gb = cpu_mem_gb
        self.cpu_mem_b_count = 0
        self.device_index = device_index
        self.execution_order = []
        self.execution_order_idx = {}
        self.pin_memory = False
        test_data = torch.rand(1,1, device='cpu')
        pin_data = test_data.pin_memory()
        self.pin_memory = pin_data.is_pinned()
        print(f"pin:{self.pin_memory}")
        self.copy_stream = torch.cuda.Stream()
        self.copy_queue = queue.Queue()
        self.layer_param:Dict[str, LayerParamStruct] = {}
        self.model_map = {}
        self.stop_flag = False
        self.copy_condition = threading.Condition()
        self.queue_condition = threading.Condition()
        self.mem_line_b = 0

        self.copy_thread = threading.Thread(target=self._copy_thread_fun)
        self.copy_thread.daemon = True
        self.copy_thread.start()

        self.cur_copy_idx = 0
        self.execute_over = False
        self.pre_copy_step = pre_copy_step

        self.tmp_state_list = []
        self.tmp_state_idx = 0
        for i in range(pre_copy_step + 2):
            self.tmp_state_list.append(None)

        self.debug = debug

    def stop(self):
        self.stop_flag = True
        with self.queue_condition:
            self.queue_condition.notify()
        self.copy_thread.join()

        del self.layer_param
        del self.model_map
        del self.copy_stream

    def _copy_thread_fun(self):
        while self.stop_flag == False:
            layer_name = "--"
            with self.queue_condition:
                while self.copy_queue.qsize() == 0 and self.stop_flag == False:
                    self.queue_condition.wait()
                if self.stop_flag == True:
                    break
                layer_name = self.copy_queue.get()
            with torch.cuda.stream(self.copy_stream):
                if layer_name in self.model_map:
                    model = self.model_map[layer_name]
                    self.tmp_state_list[self.tmp_state_idx] = {
                        k: v.to(torch.device(f"cuda:{self.device_index}"), non_blocking=False)
                        for k, v in model.state_dict().items()
                    }
                    self.copy_stream.synchronize()

                    device_state = self.tmp_state_list[self.tmp_state_idx]
                    self.tmp_state_idx = (self.tmp_state_idx + 1) % len(self.tmp_state_list)

                    with self.copy_condition:
                        if layer_name in self.layer_param:
                            self.layer_param[layer_name].count += 1
                        else:
                            self.layer_param[layer_name] = LayerParamStruct()
                            self.layer_param[layer_name].count = 1
                        self.layer_param[layer_name].device_state = device_state
                        self.copy_condition.notify()
                else:
                    print(f"get model error! {layer_name}")
        print("copy thread stop..")

    def _get_new_step_copy_begin_end(self, tag_name):

        pre_copy_step = self.pre_copy_step
        pre_copy_step = min(pre_copy_step, len(self.execution_order) // 2)

        cur_exe_idx = self.execution_order_idx[tag_name]
        copy_begin = self.cur_copy_idx
        copy_end = cur_exe_idx + pre_copy_step + 1
        if copy_end - copy_begin > len(self.execution_order):
            copy_end %= len(self.execution_order)
        if copy_end - copy_begin > pre_copy_step + 1 or copy_end - copy_begin < 0:
            # jump
            self.cur_copy_idx = cur_exe_idx
            copy_begin, copy_end = self._get_new_step_copy_begin_end(tag_name=tag_name)
        return copy_begin, copy_end

    def make_forward_wrapper(self, module, tag_name, ignore_layer_list=[]):
        original_forward = module.forward
        layer_param_size = 0
        for name, param in module.named_parameters():
            layer_param_size += param.data.numel() * param.data.element_size() / 1024 / 1024 #MB

        taget_cpu_mem_b = self.cpu_mem_gb * 1024 * 1024 * 1024
        offload = False
        for name, param in module.named_parameters():
            p_name = f"{tag_name}.{name}" if tag_name else name
            for i_layer in ignore_layer_list:
                if p_name.startswith(i_layer):
                    if self.debug:
                        print(f"ignore layer param: {p_name}")
                    continue

            if taget_cpu_mem_b >= 0 and self.cpu_mem_b_count >= taget_cpu_mem_b:
                break
            cpu_data = torch.empty_strided(size=param.data.size(),
                                           stride=param.data.stride(),
                                           dtype=param.data.dtype,
                                           layout=param.data.layout,
                                           device='cpu',
                                           pin_memory=self.pin_memory)
            cpu_data.copy_(param.data)
            param.data = cpu_data

            param_size = param.data.numel() * param.data.element_size()
            self.cpu_mem_b_count += param_size
            offload = True
        if self.debug:
            print(f"layer: {tag_name}, type: {module.__class__.__name__}, size(MB): {layer_param_size}, offload: {offload}, sum_offload_size(MB): {self.cpu_mem_b_count/1024/1024}")

        if offload:
            copy_condition = self.copy_condition
            queue_condition = self.queue_condition
            copy_queue = self.copy_queue
            layer_param = self.layer_param
            def forward_wrapper(*args, **kwargs):
                module.forward = original_forward

                execute_over = False if tag_name not in self.execution_order_idx else True
                if execute_over == False:
                    self.model_map[tag_name] = module
                    self.execution_order.append(tag_name)
                    self.execution_order_idx[tag_name] = len(self.execution_order) - 1
                    copy_queue.put(tag_name)
                    with queue_condition:
                        queue_condition.notify()
                else:

                    copy_begin, copy_end = self._get_new_step_copy_begin_end(tag_name=tag_name)
                    if copy_end > copy_begin:
                        for idx in range(copy_begin, copy_end):
                            idx = idx % len(self.execution_order)
                            copy_tag_name = self.execution_order[idx]
                            copy_queue.put(copy_tag_name)
                            with queue_condition:
                                queue_condition.notify()

                        self.cur_copy_idx = copy_end % len(self.execution_order)

                run_state = None
                with self.copy_condition:
                    while tag_name not in self.layer_param:
                        copy_condition.wait()
                    run_state = self.layer_param[tag_name].device_state
                    self.layer_param[tag_name].count -= 1

                module.eval()
                with torch.no_grad():
                    output = functional_call(module, run_state, args=args, kwargs=kwargs)
                with self.copy_condition:
                    if self.layer_param[tag_name].count == 0:
                        del self.layer_param[tag_name]
                diff_mem_b_thre = 1 * (1024 ** 3)
                if self.clean_cache_after_forward:
                    reserved = torch.cuda.memory_reserved()
                    if reserved > self.mem_line_b:
                        torch.cuda.empty_cache()
                        cur_reserved = torch.cuda.memory_reserved()
                        diff_mem = reserved - cur_reserved
                        if diff_mem > diff_mem_b_thre:
                            self.mem_line_b = cur_reserved + (reserved - cur_reserved) / 2 + 10
                        else:
                            self.mem_line_b = reserved + 10
                        if self.debug:
                            print(f"child mem line update, clean cache:{reserved/1024/1024}, cur mem: {cur_reserved/1024/1024} new limit: {self.mem_line_b / 1024 / 1024}, child name: {tag_name}")

                module.forward = forward_wrapper
                return output
            module.forward = forward_wrapper

        torch.cuda.empty_cache()
        return module

    def reset_empty_cache_mem_line(self):
        self.mem_line_b = 0
        torch.cuda.empty_cache()

    def clean_cache_wrapper(self, module, method_name='', diff_mem_gb_thre=1):
        if not hasattr(module, method_name) or not callable(getattr(module, method_name)):
            print(f"no this method {method_name}")
            return module

        original_fun = getattr(module, method_name)
        diff_mem_b_thre = diff_mem_gb_thre * (1024 ** 3)
        self.reset_empty_cache_mem_line()

        def clean_wrapper(*args, **kwargs):
            setattr(module, method_name, original_fun)
            output = original_fun(*args, **kwargs)
            reserved = torch.cuda.memory_reserved()
            if reserved > self.mem_line_b:
                torch.cuda.empty_cache()
                cur_reserved = torch.cuda.memory_reserved()
                diff_mem = reserved - cur_reserved
                if diff_mem > diff_mem_b_thre:
                    self.mem_line_b = cur_reserved + (reserved - cur_reserved) / 2 + 10
                else:
                    self.mem_line_b = reserved + 10

                if self.debug:
                    print(f"mem line update, clean cache:{reserved/1024/1024}, cur mem: {cur_reserved/1024/1024} new limit: {self.mem_line_b / 1024 / 1024}")
            setattr(module, method_name, clean_wrapper)
            return output

        setattr(module, method_name, clean_wrapper)
        return module

    @_callable_once
    def offload_layer(self, module, offload_layer_dict={}, ignore_layer_list=[], dtype:torch.dtype = None):
        return self._offload_layer(
            module=module,
            tag="",
            offload_layer_dict=offload_layer_dict,
            ignore_layer_list=ignore_layer_list,
            dtype=dtype
        )

    def _offload_layer(self, module, tag="", offload_layer_dict={}, ignore_layer_list=[], dtype:torch.dtype = None):
        """
        Offload specific layers of a PyTorch model to a specified depth.
        A model can only be offloaded once.

        Args:
            module (torch.nn.Module):
                The PyTorch model containing the layers to offload. This is the model that will be modified in place.

            tag (str, optional):
                A string identifier for the model.
                Default is an empty string.

            offload_layer_dict (dict, optional):
                A dictionary where keys are layer names and values represent the depth at which the offloading should occur.
                For example,
                ```offload_layer_dict = {'cfm_wrapper': 5, 'hubert': 4}``` means that the `cfm_wrapper` layer should
                be offloaded at depth 5, and the `hubert` layer should be offloaded at depth 4.
                Default is an empty dictionary.

            ignore_layer_list (list, optional):
                A list of layer names or parameter identifiers to be ignored during the offloading process.
                Layers in this list will not be offloaded, even if they are present in the `offload_layer_dict`.
                For example,
                ```ignore_layer_list = ['cfm_wrapper.estimator.h', 'cfm_wrapper.estimator.adaln_single']```
                means that layers starting with `cfm_wrapper.estimator.h` or 'cfm_wrapper.estimator.adaln_single' will not be offloaded.
                Default is an empty list.

            dtype (torch.dtype, optional):
                The data type (e.g., `torch.float16`, `torch.float32`) to which the offloaded layers should be converted.
                If `None`, the data type of the layers will remain unchanged. Default is `None`.

        Returns:
            None
        """
        for p in module._parameters.values():
            if p is not None:
                p.data = p.data.to(torch.device(f"cuda:{self.device_index}"))
                if dtype is not None:
                    p.data = p.data.to(dtype)
        for b in module._buffers.values():
            if b is not None:
                b.data = b.data.to(torch.device(f"cuda:{self.device_index}"))
                if dtype is not None:
                    b.data = b.data.to(dtype)
        for attr_name, attr in module.__dict__.items():
            if isinstance(attr, torch.Tensor) and not attr_name.startswith('_'):
                attr.data = attr.data.to(torch.device(f"cuda:{self.device_index}"))
                if dtype is not None:
                    attr.data = attr.data.to(dtype)

        for name, child in module.named_children():
            current_tag = f"{tag}.{name}" if tag else name
            child = child.to(torch.device(f"cuda:{self.device_index}"))
            if dtype is not None:
                child = child.to(dtype)

            torch.cuda.empty_cache()
            setattr(module, name, child)
            pre_name = current_tag.split('.')[0]
            if pre_name not in offload_layer_dict:
                param_size = 0
                for p in child.parameters():
                    param_size += p.data.numel() * p.data.element_size()
                param_size = param_size / 1024 / 1024
                if self.debug:
                    print(f"not offload layer {current_tag}, size: {param_size}MB")
                continue

            has_children = any(child.named_children())
            layer_count = current_tag.count('.') + 1

            layer_deep = offload_layer_dict[pre_name]
            if layer_count >= layer_deep:
                has_children = False

            if has_children:
                self._offload_layer(module=child,
                                    tag=current_tag,
                                    offload_layer_dict=offload_layer_dict,
                                    ignore_layer_list=ignore_layer_list,
                                    dtype=dtype)
                continue

            ignore = False
            for i_layer in ignore_layer_list:
                if current_tag.startswith(i_layer):
                    ignore = True
                    if self.debug:
                        print(f"ignore layer offload: {current_tag}")
                    break

            if hasattr(child, "forward") and not ignore:
                child = self.make_forward_wrapper(
                    child, current_tag, ignore_layer_list=ignore_layer_list
                )
        return module

    def get_execution_order(self):
        return self.execution_order
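A hedged sketch of how the profiler is meant to be driven, based on the parameter helpers above; the `model` object and the `cfg.offload` config node are placeholders for whatever the caller supplies:

```python
# Sketch only: wire OffloadParam into OffloadProfiler and offload the configured layers.
offload_param = OffloadParamParse.parse_config(model, cfg.offload)
offload_param.show()
profiler = OffloadProfiler(device_index=0, **offload_param.init_param_dict())
profiler.offload_layer(**offload_param.offload_layer_param_dict())
clean_cache_args = offload_param.clean_cache_param_dict()
if clean_cache_args:
    profiler.clean_cache_wrapper(**clean_cache_args)
# ... run inference with `model` ...
profiler.stop()
```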
download.py
CHANGED
@@ -7,7 +7,7 @@ def download_model(local_dir):
     downloaded_path = snapshot_download(
         repo_id=repo_id,
         local_dir=local_dir,
-        revision="
+        revision="647f0a5",
         token=os.environ.get("HF_TOKEN"),
         ignore_patterns=['.git*']
    )
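Pinning `revision` makes the checkpoint snapshot reproducible across runs. A minimal call sketch (the local directory is arbitrary):

```python
# Sketch only; download_model() wraps huggingface_hub.snapshot_download with the pinned revision.
from download import download_model

download_model("./ckpt")
```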
generate.py
CHANGED
@@ -1,5 +1,7 @@
+from hmac import new
 import sys
 import os
+import argparse
 
 import time
 import json
@@ -7,11 +9,13 @@ import torch
 import torchaudio
 import numpy as np
 from omegaconf import OmegaConf
-
+from codeclm.models import builders
+import gc
 from codeclm.trainer.codec_song_pl import CodecLM_PL
 from codeclm.models import CodecLM
 from third_party.demucs.models.pretrained import get_model_from_yaml
 
+
 auto_prompt_type = ['Pop', 'R&B', 'Dance', 'Jazz', 'Folk', 'Rock', 'Chinese Style', 'Chinese Tradition', 'Metal', 'Reggae', 'Chinese Opera', 'Auto']
 
 class Separator:
@@ -34,8 +38,6 @@ class Separator:
         a = torchaudio.functional.resample(a, fs, 48000)
         if a.shape[-1] >= 48000*10:
             a = a[..., :48000*10]
-        else:
-            a = torch.cat([a, a], -1)
         return a[:, 0:48000*10]
 
     def run(self, audio_path, output_dir='tmp', ext=".flac"):
@@ -59,38 +61,146 @@ class Separator:
         return full_audio, vocal_audio, bgm_audio
 
 
-
-
-
-
-    OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
-    OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
-    np.random.seed(int(time.time()))
-    ckpt_path = sys.argv[1]
-    input_jsonl = sys.argv[2]
-    save_dir = sys.argv[3]
-    gen_type = sys.argv[4] if len(sys.argv) > 4 else "all"
+def parse_args():
+    parser = argparse.ArgumentParser(description='Song Generation Script')
+
+    # required arguments
+    parser.add_argument('--ckpt_path', type=str, required=True,
+                        help='Path to the checkpoint directory containing config.yaml and model.pt')
+    parser.add_argument('--input_jsonl', type=str, required=True,
+                        help='Path to input JSONL file containing generation tasks')
+    parser.add_argument('--save_dir', type=str, required=True,
+                        help='Directory to save generated audio files and results')
+    # optional arguments
+    parser.add_argument('--generate_type', type=str, default='mixed',
+                        help='Type of generation: "vocal" or "bgm" or "separate" or "mixed" (default: "mixed")')
+    parser.add_argument('--use_flash_attn', action='store_true',
+                        help='Whether to use flash attention (default: False)')
+    parser.add_argument('--low_mem', action='store_true',
+                        help='Whether to use low memory mode (default: False)')
+    return parser.parse_args()
 
+def generate(args):
+    ckpt_path = args.ckpt_path
+    input_jsonl = args.input_jsonl
+    save_dir = args.save_dir
     cfg_path = os.path.join(ckpt_path, 'config.yaml')
     ckpt_path = os.path.join(ckpt_path, 'model.pt')
     cfg = OmegaConf.load(cfg_path)
+    cfg.lm.use_flash_attn_2 = args.use_flash_attn
+    print(f"use_flash_attn: {args.use_flash_attn}")
     cfg.mode = 'inference'
     max_duration = cfg.max_dur
+    gen_type = args.generate_type
 
-    # Define model or load pretrained model
-    model_light = CodecLM_PL(cfg, ckpt_path)
 
-    model_light = model_light.eval().cuda()
-    model_light.audiolm.cfg = cfg
-    model = CodecLM(name = "tmp",
-        lm = model_light.audiolm,
-        audiotokenizer = model_light.audio_tokenizer,
-        max_duration = max_duration,
-        seperate_tokenizer = model_light.seperate_tokenizer,
-        )
     separator = Separator()
     auto_prompt = torch.load('ckpt/prompt.pt')
+    audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
+    audio_tokenizer = audio_tokenizer.eval().cuda()
     merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
+    with open(input_jsonl, "r") as fp:
+        lines = fp.readlines()
+
+
+    new_items = []
+    for line in lines:
+        item = json.loads(line)
+        target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
+        # get prompt audio
+        if "prompt_audio_path" in item:
+            assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
+            assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
+            with torch.no_grad():
+                pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
+            item['raw_pmt_wav'] = pmt_wav
+            item['raw_vocal_wav'] = vocal_wav
+            item['raw_bgm_wav'] = bgm_wav
+            if pmt_wav.dim() == 2:
+                pmt_wav = pmt_wav[None]
+            if pmt_wav.dim() != 3:
+                raise ValueError("Melody wavs should have a shape [B, C, T].")
+            pmt_wav = list(pmt_wav)
+            if vocal_wav.dim() == 2:
+                vocal_wav = vocal_wav[None]
+            if vocal_wav.dim() != 3:
+                raise ValueError("Vocal wavs should have a shape [B, C, T].")
+            vocal_wav = list(vocal_wav)
+            if bgm_wav.dim() == 2:
+                bgm_wav = bgm_wav[None]
+            if bgm_wav.dim() != 3:
+                raise ValueError("BGM wavs should have a shape [B, C, T].")
+            bgm_wav = list(bgm_wav)
+            if type(pmt_wav) == list:
+                pmt_wav = torch.stack(pmt_wav, dim=0)
+            if type(vocal_wav) == list:
+                vocal_wav = torch.stack(vocal_wav, dim=0)
     cfg_coef = 1.5 #25
     temp = 0.9
     top_k = 50
@@ -104,21 +214,135 @@ if __name__ == "__main__":
     os.makedirs(save_dir + "/audios", exist_ok=True)
     os.makedirs(save_dir + "/jsonl", exist_ok=True)
 
     with open(input_jsonl, "r") as fp:
         lines = fp.readlines()
-
     new_items = []
     for line in lines:
         item = json.loads(line)
         target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
-        lyric = item["gt_lyric"]
-        descriptions = item["descriptions"] if "descriptions" in item else None
         # get prompt audio
         if "prompt_audio_path" in item:
             assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
             assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
-
-
         elif "auto_prompt_audio_type" in item:
             assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
             if item["auto_prompt_audio_type"] == "Auto":
@@ -134,6 +358,86 @@ if __name__ == "__main__":
             vocal_wav = None
             bgm_wav = None
             melody_is_wav = True
 
         generate_inp = {
             'lyrics': [lyric.replace(" ", " ")],
@@ -143,25 +447,119 @@ if __name__ == "__main__":
             'bgm_wavs': bgm_wav,
             'melody_is_wav': melody_is_wav,
         }
-        start_time = time.time()
         with torch.autocast(device_type="cuda", dtype=torch.float16):
-
-
 
         with torch.no_grad():
-            if
-
             else:
-
-
-
-
-
-
-
-
 
     src_jsonl_name = os.path.split(input_jsonl)[-1]
     with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
         for item in new_items:
             fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
136 |
+
if type(vocal_wav) == list:
|
137 |
+
vocal_wav = torch.stack(vocal_wav, dim=0)
|
138 |
+
if type(bgm_wav) == list:
|
139 |
+
bgm_wav = torch.stack(bgm_wav, dim=0)
|
140 |
+
pmt_wav = pmt_wav
|
141 |
+
vocal_wav = vocal_wav
|
142 |
+
bgm_wav = bgm_wav
|
143 |
+
with torch.no_grad():
|
144 |
+
pmt_wav, _ = audio_tokenizer.encode(pmt_wav.cuda())
|
145 |
+
melody_is_wav = False
|
146 |
+
elif "auto_prompt_audio_type" in item:
|
147 |
+
assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
|
148 |
+
if item["auto_prompt_audio_type"] == "Auto":
|
149 |
+
prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
|
150 |
+
else:
|
151 |
+
prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
|
152 |
+
pmt_wav = prompt_token[:,[0],:]
|
153 |
+
vocal_wav = prompt_token[:,[1],:]
|
154 |
+
bgm_wav = prompt_token[:,[2],:]
|
155 |
+
melody_is_wav = False
|
156 |
+
else:
|
157 |
+
pmt_wav = None
|
158 |
+
vocal_wav = None
|
159 |
+
bgm_wav = None
|
160 |
+
melody_is_wav = True
|
161 |
+
item['pmt_wav'] = pmt_wav
|
162 |
+
item['vocal_wav'] = vocal_wav
|
163 |
+
item['bgm_wav'] = bgm_wav
|
164 |
+
item['melody_is_wav'] = melody_is_wav
|
165 |
+
item["idx"] = f"{item['idx']}"
|
166 |
+
item["wav_path"] = target_wav_name
|
167 |
+
new_items.append(item)
|
168 |
+
|
169 |
+
del audio_tokenizer
|
170 |
+
del separator
|
171 |
+
|
172 |
+
torch.cuda.empty_cache()
|
173 |
+
|
174 |
+
if "audio_tokenizer_checkpoint_sep" in cfg.keys():
|
175 |
+
seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
|
176 |
+
else:
|
177 |
+
seperate_tokenizer = None
|
178 |
+
|
179 |
+
if seperate_tokenizer is not None:
|
180 |
+
seperate_tokenizer = seperate_tokenizer.eval().cuda()
|
181 |
+
|
182 |
+
for item in new_items:
|
183 |
+
if "prompt_audio_path" in item:
|
184 |
+
with torch.no_grad():
|
185 |
+
vocal_wav, bgm_wav = seperate_tokenizer.encode(item['vocal_wav'].cuda(), item['bgm_wav'].cuda())
|
186 |
+
item['vocal_wav'] = vocal_wav
|
187 |
+
item['bgm_wav'] = bgm_wav
|
188 |
+
|
189 |
+
torch.cuda.empty_cache()
|
190 |
+
audiolm = builders.get_lm_model(cfg)
|
191 |
+
checkpoint = torch.load(ckpt_path, map_location='cpu')
|
192 |
+
audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
|
193 |
+
audiolm.load_state_dict(audiolm_state_dict, strict=False)
|
194 |
+
audiolm = audiolm.eval()
|
195 |
+
audiolm = audiolm.cuda().to(torch.float16)
|
196 |
+
|
197 |
+
model = CodecLM(name = "tmp",
|
198 |
+
lm = audiolm,
|
199 |
+
audiotokenizer = None,
|
200 |
+
max_duration = max_duration,
|
201 |
+
seperate_tokenizer = seperate_tokenizer,
|
202 |
+
)
|
203 |
+
|
204 |
cfg_coef = 1.5 #25
|
205 |
temp = 0.9
|
206 |
top_k = 50
|
|
|
214 |
os.makedirs(save_dir + "/audios", exist_ok=True)
|
215 |
os.makedirs(save_dir + "/jsonl", exist_ok=True)
|
216 |
|
217 |
+
for item in new_items:
|
218 |
+
lyric = item["gt_lyric"]
|
219 |
+
descriptions = item["descriptions"] if "descriptions" in item else None
|
220 |
+
pmt_wav = item['pmt_wav']
|
221 |
+
vocal_wav = item['vocal_wav']
|
222 |
+
bgm_wav = item['bgm_wav']
|
223 |
+
melody_is_wav = item['melody_is_wav']
|
224 |
+
target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
|
225 |
+
|
226 |
+
|
227 |
+
generate_inp = {
|
228 |
+
'lyrics': [lyric.replace(" ", " ")],
|
229 |
+
'descriptions': [descriptions],
|
230 |
+
'melody_wavs': pmt_wav,
|
231 |
+
'vocal_wavs': vocal_wav,
|
232 |
+
'bgm_wavs': bgm_wav,
|
233 |
+
'melody_is_wav': melody_is_wav,
|
234 |
+
}
|
235 |
+
start_time = time.time()
|
236 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
237 |
+
with torch.no_grad():
|
238 |
+
tokens = model.generate(**generate_inp, return_tokens=True)
|
239 |
+
mid_time = time.time()
|
240 |
+
|
241 |
+
with torch.no_grad():
|
242 |
+
if 'raw_pmt_wav' in item:
|
243 |
+
if gen_type == 'separate':
|
244 |
+
wav_seperate = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='mixed')
|
245 |
+
wav_vocal = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='vocal')
|
246 |
+
wav_bgm = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='bgm')
|
247 |
+
elif gen_type == 'mixed':
|
248 |
+
wav_seperate = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type=gen_type)
|
249 |
+
else:
|
250 |
+
wav_seperate = model.generate_audio(tokens,chunked=True, gen_type=gen_type)
|
251 |
+
del item['raw_pmt_wav']
|
252 |
+
del item['raw_vocal_wav']
|
253 |
+
del item['raw_bgm_wav']
|
254 |
+
else:
|
255 |
+
if gen_type == 'separate':
|
256 |
+
wav_vocal = model.generate_audio(tokens, chunked=True, gen_type='vocal')
|
257 |
+
wav_bgm = model.generate_audio(tokens, chunked=True, gen_type='bgm')
|
258 |
+
wav_seperate = model.generate_audio(tokens, chunked=True, gen_type='mixed')
|
259 |
+
else:
|
260 |
+
wav_seperate = model.generate_audio(tokens, chunked=True, gen_type=gen_type)
|
261 |
+
del item['pmt_wav']
|
262 |
+
del item['vocal_wav']
|
263 |
+
del item['bgm_wav']
|
264 |
+
del item['melody_is_wav']
|
265 |
+
end_time = time.time()
|
266 |
+
if gen_type == 'separate':
|
267 |
+
torchaudio.save(target_wav_name.replace('.flac', '_vocal.flac'), wav_vocal[0].cpu().float(), cfg.sample_rate)
|
268 |
+
torchaudio.save(target_wav_name.replace('.flac', '_bgm.flac'), wav_bgm[0].cpu().float(), cfg.sample_rate)
|
269 |
+
torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)
|
270 |
+
else:
|
271 |
+
torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)
|
272 |
+
|
273 |
+
print(f"process{item['idx']}, lm cost {mid_time - start_time}s, diffusion cost {end_time - mid_time}")
|
274 |
+
item["idx"] = f"{item['idx']}"
|
275 |
+
item["wav_path"] = target_wav_name
|
276 |
+
|
277 |
+
src_jsonl_name = os.path.split(input_jsonl)[-1]
|
278 |
+
with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
|
279 |
+
for item in new_items:
|
280 |
+
fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
|
281 |
+
|
282 |
+
def generate_lowmem(args):
|
283 |
+
ckpt_path = args.ckpt_path
|
284 |
+
input_jsonl = args.input_jsonl
|
285 |
+
save_dir = args.save_dir
|
286 |
+
cfg_path = os.path.join(ckpt_path, 'config.yaml')
|
287 |
+
ckpt_path = os.path.join(ckpt_path, 'model.pt')
|
288 |
+
cfg = OmegaConf.load(cfg_path)
|
289 |
+
cfg.lm.use_flash_attn_2 = args.use_flash_attn
|
290 |
+
print(f"use_flash_attn: {args.use_flash_attn}")
|
291 |
+
cfg.mode = 'inference'
|
292 |
+
max_duration = cfg.max_dur
|
293 |
+
gen_type = args.generate_type
|
294 |
+
chunk_size = 128
|
295 |
+
use_audio_tokenizer = False
|
296 |
with open(input_jsonl, "r") as fp:
|
297 |
lines = fp.readlines()
|
298 |
+
for line in lines:
|
299 |
+
item = json.loads(line)
|
300 |
+
if "prompt_audio_path" in item:
|
301 |
+
use_audio_tokenizer = True
|
302 |
+
break
|
303 |
+
if use_audio_tokenizer:
|
304 |
+
separator = Separator()
|
305 |
+
audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
|
306 |
+
audio_tokenizer = audio_tokenizer.eval().cuda()
|
307 |
+
auto_prompt = torch.load('ckpt/prompt.pt')
|
308 |
+
merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
|
309 |
new_items = []
|
310 |
for line in lines:
|
311 |
item = json.loads(line)
|
312 |
target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
|
|
|
|
|
313 |
# get prompt audio
|
314 |
if "prompt_audio_path" in item:
|
315 |
assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
|
316 |
assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
|
317 |
+
with torch.no_grad():
|
318 |
+
pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
|
319 |
+
item['raw_pmt_wav'] = pmt_wav
|
320 |
+
item['raw_vocal_wav'] = vocal_wav
|
321 |
+
item['raw_bgm_wav'] = bgm_wav
|
322 |
+
if pmt_wav.dim() == 2:
|
323 |
+
pmt_wav = pmt_wav[None]
|
324 |
+
if pmt_wav.dim() != 3:
|
325 |
+
raise ValueError("Melody wavs should have a shape [B, C, T].")
|
326 |
+
pmt_wav = list(pmt_wav)
|
327 |
+
if vocal_wav.dim() == 2:
|
328 |
+
vocal_wav = vocal_wav[None]
|
329 |
+
if vocal_wav.dim() != 3:
|
330 |
+
raise ValueError("Vocal wavs should have a shape [B, C, T].")
|
331 |
+
vocal_wav = list(vocal_wav)
|
332 |
+
if bgm_wav.dim() == 2:
|
333 |
+
bgm_wav = bgm_wav[None]
|
334 |
+
if bgm_wav.dim() != 3:
|
335 |
+
raise ValueError("BGM wavs should have a shape [B, C, T].")
|
336 |
+
bgm_wav = list(bgm_wav)
|
337 |
+
if type(pmt_wav) == list:
|
338 |
+
pmt_wav = torch.stack(pmt_wav, dim=0)
|
339 |
+
if type(vocal_wav) == list:
|
340 |
+
vocal_wav = torch.stack(vocal_wav, dim=0)
|
341 |
+
if type(bgm_wav) == list:
|
342 |
+
bgm_wav = torch.stack(bgm_wav, dim=0)
|
343 |
+
with torch.no_grad():
|
344 |
+
pmt_wav, _ = audio_tokenizer.encode(pmt_wav.cuda())
|
345 |
+
melody_is_wav = False
|
346 |
elif "auto_prompt_audio_type" in item:
|
347 |
assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
|
348 |
if item["auto_prompt_audio_type"] == "Auto":
|
|
|
358 |
vocal_wav = None
|
359 |
bgm_wav = None
|
360 |
melody_is_wav = True
|
361 |
+
item['pmt_wav'] = pmt_wav
|
362 |
+
item['vocal_wav'] = vocal_wav
|
363 |
+
item['bgm_wav'] = bgm_wav
|
364 |
+
item['melody_is_wav'] = melody_is_wav
|
365 |
+
item["idx"] = f"{item['idx']}"
|
366 |
+
item["wav_path"] = target_wav_name
|
367 |
+
new_items.append(item)
|
368 |
+
|
369 |
+
if use_audio_tokenizer:
|
370 |
+
del audio_tokenizer
|
371 |
+
del separator
|
372 |
+
|
373 |
+
torch.cuda.empty_cache()
|
374 |
+
|
375 |
+
if "audio_tokenizer_checkpoint_sep" in cfg.keys() and use_audio_tokenizer:
|
376 |
+
seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
|
377 |
+
else:
|
378 |
+
seperate_tokenizer = None
|
379 |
+
|
380 |
+
if seperate_tokenizer is not None:
|
381 |
+
seperate_tokenizer = seperate_tokenizer.eval().cuda()
|
382 |
+
|
383 |
+
for item in new_items:
|
384 |
+
if "prompt_audio_path" in item:
|
385 |
+
with torch.no_grad():
|
386 |
+
vocal_wav, bgm_wav = seperate_tokenizer.encode(item['vocal_wav'].cuda(), item['bgm_wav'].cuda())
|
387 |
+
item['vocal_wav'] = vocal_wav
|
388 |
+
item['bgm_wav'] = bgm_wav
|
389 |
+
|
390 |
+
if use_audio_tokenizer:
|
391 |
+
del seperate_tokenizer
|
392 |
+
|
393 |
+
torch.cuda.empty_cache()
|
394 |
+
|
395 |
+
# Define model or load pretrained model
|
396 |
+
audiolm = builders.get_lm_model(cfg)
|
397 |
+
checkpoint = torch.load(ckpt_path, map_location='cpu')
|
398 |
+
audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
|
399 |
+
audiolm.load_state_dict(audiolm_state_dict, strict=False)
|
400 |
+
audiolm = audiolm.eval()
|
401 |
+
|
402 |
+
offload_audiolm = True if 'offload' in cfg.keys() and 'audiolm' in cfg.offload else False
|
403 |
+
if offload_audiolm:
|
404 |
+
audiolm_offload_param = OffloadParamParse.parse_config(audiolm, cfg.offload.audiolm)
|
405 |
+
audiolm_offload_param.show()
|
406 |
+
offload_profiler = OffloadProfiler(device_index=0, **(audiolm_offload_param.init_param_dict()))
|
407 |
+
offload_profiler.offload_layer(**(audiolm_offload_param.offload_layer_param_dict()))
|
408 |
+
offload_profiler.clean_cache_wrapper(**(audiolm_offload_param.clean_cache_param_dict()))
|
409 |
+
else:
|
410 |
+
audiolm = audiolm.cuda().to(torch.float16)
|
411 |
+
|
412 |
+
model = CodecLM(name = "tmp",
|
413 |
+
lm = audiolm,
|
414 |
+
audiotokenizer = None,
|
415 |
+
max_duration = max_duration,
|
416 |
+
seperate_tokenizer = None,
|
417 |
+
)
|
418 |
+
|
419 |
+
cfg_coef = 1.5 #25
|
420 |
+
temp = 0.9
|
421 |
+
top_k = 50
|
422 |
+
top_p = 0.0
|
423 |
+
record_tokens = True
|
424 |
+
record_window = 50
|
425 |
+
|
426 |
+
|
427 |
+
model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
|
428 |
+
top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
|
429 |
+
os.makedirs(save_dir, exist_ok=True)
|
430 |
+
os.makedirs(save_dir + "/audios", exist_ok=True)
|
431 |
+
os.makedirs(save_dir + "/jsonl", exist_ok=True)
|
432 |
+
|
433 |
+
|
434 |
+
for item in new_items:
|
435 |
+
lyric = item["gt_lyric"]
|
436 |
+
descriptions = item["descriptions"] if "descriptions" in item else None
|
437 |
+
pmt_wav = item['pmt_wav']
|
438 |
+
vocal_wav = item['vocal_wav']
|
439 |
+
bgm_wav = item['bgm_wav']
|
440 |
+
melody_is_wav = item['melody_is_wav']
|
441 |
|
442 |
generate_inp = {
|
443 |
'lyrics': [lyric.replace(" ", " ")],
|
|
|
447 |
'bgm_wavs': bgm_wav,
|
448 |
'melody_is_wav': melody_is_wav,
|
449 |
}
|
|
|
450 |
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
451 |
+
with torch.no_grad():
|
452 |
+
tokens = model.generate(**generate_inp, return_tokens=True)
|
453 |
+
if offload_audiolm:
|
454 |
+
offload_profiler.reset_empty_cache_mem_line()
|
455 |
+
item['tokens'] = tokens
|
456 |
+
if offload_audiolm:
|
457 |
+
offload_profiler.stop()
|
458 |
+
del offload_profiler
|
459 |
+
del audiolm_offload_param
|
460 |
+
del model
|
461 |
+
audiolm = audiolm.cpu()
|
462 |
+
del audiolm
|
463 |
+
del checkpoint
|
464 |
+
gc.collect()
|
465 |
+
torch.cuda.empty_cache()
|
466 |
+
|
467 |
+
seperate_tokenizer = builders.get_audio_tokenizer_model_cpu(cfg.audio_tokenizer_checkpoint_sep, cfg)
|
468 |
+
device = "cuda:0"
|
469 |
+
seperate_tokenizer.model.device = device
|
470 |
+
seperate_tokenizer.model.vae = seperate_tokenizer.model.vae.to(device)
|
471 |
+
seperate_tokenizer.model.model.device = torch.device(device)
|
472 |
+
seperate_tokenizer = seperate_tokenizer.eval()
|
473 |
+
|
474 |
+
offload_wav_tokenizer_diffusion = True if 'offload' in cfg.keys() and 'wav_tokenizer_diffusion' in cfg.offload else False
|
475 |
+
if offload_wav_tokenizer_diffusion:
|
476 |
+
sep_offload_param = OffloadParamParse.parse_config(seperate_tokenizer, cfg.offload.wav_tokenizer_diffusion)
|
477 |
+
sep_offload_param.show()
|
478 |
+
sep_offload_profiler = OffloadProfiler(device_index=0, **(sep_offload_param.init_param_dict()))
|
479 |
+
sep_offload_profiler.offload_layer(**(sep_offload_param.offload_layer_param_dict()))
|
480 |
+
sep_offload_profiler.clean_cache_wrapper(**(sep_offload_param.clean_cache_param_dict()))
|
481 |
+
else:
|
482 |
+
seperate_tokenizer.model.model = seperate_tokenizer.model.model.to(device)
|
483 |
+
|
484 |
+
model = CodecLM(name = "tmp",
|
485 |
+
lm = None,
|
486 |
+
audiotokenizer = None,
|
487 |
+
max_duration = max_duration,
|
488 |
+
seperate_tokenizer = seperate_tokenizer,
|
489 |
+
)
|
490 |
|
491 |
+
for item in new_items:
|
492 |
with torch.no_grad():
|
493 |
+
if 'raw_pmt_wav' in item:
|
494 |
+
if gen_type == 'separate':
|
495 |
+
wav_seperate = model.generate_audio(item['tokens'], item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type='mixed')
|
496 |
+
wav_vocal = model.generate_audio(item['tokens'],chunked=True, gen_type='vocal')
|
497 |
+
wav_bgm = model.generate_audio(item['tokens'], chunked=True, gen_type='bgm')
|
498 |
+
elif gen_type == 'mixed':
|
499 |
+
wav_seperate = model.generate_audio(item['tokens'], item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type=gen_type)
|
500 |
+
else:
|
501 |
+
wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type=gen_type)
|
502 |
+
del item['raw_pmt_wav']
|
503 |
+
del item['raw_vocal_wav']
|
504 |
+
del item['raw_bgm_wav']
|
505 |
else:
|
506 |
+
if gen_type == 'separate':
|
507 |
+
wav_vocal = model.generate_audio(item['tokens'], chunked=True, gen_type='vocal')
|
508 |
+
wav_bgm = model.generate_audio(item['tokens'], chunked=True, gen_type='bgm')
|
509 |
+
wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type='mixed')
|
510 |
+
else:
|
511 |
+
wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type=gen_type)
|
512 |
+
if gen_type == 'separate':
|
513 |
+
torchaudio.save(item['wav_path'].replace('.flac', '_vocal.flac'), wav_vocal[0].cpu().float(), cfg.sample_rate)
|
514 |
+
torchaudio.save(item['wav_path'].replace('.flac', '_bgm.flac'), wav_bgm[0].cpu().float(), cfg.sample_rate)
|
515 |
+
torchaudio.save(item['wav_path'], wav_seperate[0].cpu().float(), cfg.sample_rate)
|
516 |
+
else:
|
517 |
+
torchaudio.save(item['wav_path'], wav_seperate[0].cpu().float(), cfg.sample_rate)
|
518 |
+
del item['tokens']
|
519 |
+
del item['pmt_wav']
|
520 |
+
del item['vocal_wav']
|
521 |
+
del item['bgm_wav']
|
522 |
+
del item['melody_is_wav']
|
523 |
+
if offload_wav_tokenizer_diffusion:
|
524 |
+
sep_offload_profiler.reset_empty_cache_mem_line()
|
525 |
|
526 |
+
if offload_wav_tokenizer_diffusion:
|
527 |
+
sep_offload_profiler.stop()
|
528 |
+
torch.cuda.empty_cache()
|
529 |
src_jsonl_name = os.path.split(input_jsonl)[-1]
|
530 |
with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
|
531 |
for item in new_items:
|
532 |
fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
|
533 |
+
|
534 |
+
|
535 |
+
if __name__ == "__main__":
|
536 |
+
torch.backends.cudnn.enabled = False
|
537 |
+
OmegaConf.register_new_resolver("eval", lambda x: eval(x))
|
538 |
+
OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
|
539 |
+
OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
|
540 |
+
OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
|
541 |
+
np.random.seed(int(time.time()))
|
542 |
+
# 解析命令行参数
|
543 |
+
args = parse_args()
|
544 |
+
if torch.cuda.is_available():
|
545 |
+
device = torch.cuda.current_device()
|
546 |
+
reserved = torch.cuda.memory_reserved(device)
|
547 |
+
total = torch.cuda.get_device_properties(device).total_memory
|
548 |
+
res_mem = (total - reserved) / 1024 / 1024 / 1024
|
549 |
+
print(f"reserved memory: {res_mem}GB")
|
550 |
+
|
551 |
+
model_name = args.ckpt_path.split("/")[-1]
|
552 |
+
assert model_name in ['songgeneration_base'], f'{model_name} is not supported, currently only songgeneration_base is supported'
|
553 |
+
if model_name == 'songgeneration_base':
|
554 |
+
if res_mem > 24 and not args.low_mem:
|
555 |
+
print("use generate")
|
556 |
+
generate(args)
|
557 |
+
else:
|
558 |
+
from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
|
559 |
+
print("use generate_lowmem")
|
560 |
+
generate_lowmem(args)
|
561 |
+
|
562 |
+
else:
|
563 |
+
print("CUDA is not available")
|
564 |
+
exit()
|
565 |
+
|
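For reference, each line of the file passed via --input_jsonl is one JSON object; the keys read by the script above are idx, gt_lyric, optionally descriptions, and either prompt_audio_path or auto_prompt_audio_type. A minimal sketch of producing such a file follows; the field values and path are illustrative placeholders, not taken from this commit, and the real lyric formatting conventions are not shown here.

import json

# Sketch only: keys mirror what generate.py reads; values are placeholders.
example = {
    "idx": "demo_0",                       # names the output <save_dir>/audios/<idx>.flac
    "gt_lyric": "placeholder lyric text",  # real lyric structure not documented in this diff
    "descriptions": "an upbeat pop song",  # optional free-text description
    "auto_prompt_audio_type": "Pop",       # or supply "prompt_audio_path" instead (not both)
}
with open("sample/test.jsonl", "w", encoding="utf-8") as fw:   # placeholder path
    fw.write(json.dumps(example, ensure_ascii=False) + "\n")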
generate.sh
CHANGED
@@ -7,5 +7,66 @@ export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer
 CKPT_PATH=$1
 JSONL=$2
 SAVE_DIR=$3
-
-

 CKPT_PATH=$1
 JSONL=$2
 SAVE_DIR=$3
+USE_FLASH_ATTN="True"
+LOW_MEM="False"
+GENERATE_TYPE="mixed"
+for arg in "$@"; do
+    if [[ $arg == "--not_use_flash_attn" ]]; then
+        USE_FLASH_ATTN="False"
+    fi
+done
+for arg in "$@"; do
+    if [[ $arg == "--low_mem" ]]; then
+        LOW_MEM="True"
+    fi
+done
+for arg in "$@"; do
+    if [[ $arg == "--separate" ]]; then
+        GENERATE_TYPE="separate"
+    fi
+done
+for arg in "$@"; do
+    if [[ $arg == "--bgm" ]]; then
+        GENERATE_TYPE="bgm"
+    fi
+done
+for arg in "$@"; do
+    if [[ $arg == "--vocal" ]]; then
+        GENERATE_TYPE="vocal"
+    fi
+done
+
+
+if [ "$USE_FLASH_ATTN" == "True" ] && [ "$LOW_MEM" == "True" ]; then
+    echo "Use Flash Attention + Low Memory Mode"
+    python3 generate.py \
+        --ckpt_path $CKPT_PATH \
+        --input_jsonl $JSONL \
+        --save_dir $SAVE_DIR \
+        --generate_type $GENERATE_TYPE \
+        --use_flash_attn \
+        --low_mem
+elif [ "$USE_FLASH_ATTN" == "True" ] && [ "$LOW_MEM" == "False" ]; then
+    echo "Use Flash Attention + Auto Memory Mode"
+    python3 generate.py \
+        --ckpt_path $CKPT_PATH \
+        --input_jsonl $JSONL \
+        --save_dir $SAVE_DIR \
+        --generate_type $GENERATE_TYPE \
+        --use_flash_attn
+elif [ "$USE_FLASH_ATTN" == "False" ] && [ "$LOW_MEM" == "False" ]; then
+    echo "Not Use Flash Attention + Auto Memory Mode"
+    python3 generate.py \
+        --ckpt_path $CKPT_PATH \
+        --input_jsonl $JSONL \
+        --generate_type $GENERATE_TYPE \
+        --save_dir $SAVE_DIR
+elif [ "$USE_FLASH_ATTN" == "False" ] && [ "$LOW_MEM" == "True" ]; then
+    echo "Not Use Flash Attention + Low Memory Mode"
+    python3 generate.py \
+        --ckpt_path $CKPT_PATH \
+        --input_jsonl $JSONL \
+        --save_dir $SAVE_DIR \
+        --generate_type $GENERATE_TYPE \
+        --low_mem
+fi
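The shell wrapper above only assembles a generate.py command line from the positional arguments and optional flags. As a minimal sketch, the equivalent invocation can be driven from Python as follows; all paths are placeholders and any flag can be omitted.

import subprocess

# Sketch only: mirrors the command generate.sh builds; paths are placeholders.
subprocess.run([
    "python3", "generate.py",
    "--ckpt_path", "ckpt/songgeneration_base",
    "--input_jsonl", "sample/test.jsonl",
    "--save_dir", "output",
    "--generate_type", "separate",   # "mixed" (default), "vocal", "bgm" or "separate"
    "--use_flash_attn",              # drop this flag to run without flash attention
    "--low_mem",                     # drop this flag to let generate.py pick a mode from free VRAM
], check=True)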
generate_lowmem.py
DELETED
@@ -1,241 +0,0 @@
-import sys
-import os
-
-import time
-import json
-import torch
-import torchaudio
-import numpy as np
-from omegaconf import OmegaConf
-from codeclm.models import builders
-
-from codeclm.trainer.codec_song_pl import CodecLM_PL
-from codeclm.models import CodecLM
-from third_party.demucs.models.pretrained import get_model_from_yaml
-
-auto_prompt_type = ['Pop', 'R&B', 'Dance', 'Jazz', 'Folk', 'Rock', 'Chinese Style', 'Chinese Tradition', 'Metal', 'Reggae', 'Chinese Opera', 'Auto']
-
-class Separator:
-    def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
-        if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
-            self.device = torch.device(f"cuda:{gpu_id}")
-        else:
-            self.device = torch.device("cpu")
-        self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
-
-    def init_demucs_model(self, model_path, config_path):
-        model = get_model_from_yaml(config_path, model_path)
-        model.to(self.device)
-        model.eval()
-        return model
-
-    def load_audio(self, f):
-        a, fs = torchaudio.load(f)
-        if (fs != 48000):
-            a = torchaudio.functional.resample(a, fs, 48000)
-        if a.shape[-1] >= 48000*10:
-            a = a[..., :48000*10]
-        else:
-            a = torch.cat([a, a], -1)
-        return a[:, 0:48000*10]
-
-    def run(self, audio_path, output_dir='tmp', ext=".flac"):
-        os.makedirs(output_dir, exist_ok=True)
-        name, _ = os.path.splitext(os.path.split(audio_path)[-1])
-        output_paths = []
-
-        for stem in self.demucs_model.sources:
-            output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
-            if os.path.exists(output_path):
-                output_paths.append(output_path)
-        if len(output_paths) == 1:  # 4
-            vocal_path = output_paths[0]
-        else:
-            drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
-            for path in [drums_path, bass_path, other_path]:
-                os.remove(path)
-        full_audio = self.load_audio(audio_path)
-        vocal_audio = self.load_audio(vocal_path)
-        bgm_audio = full_audio - vocal_audio
-        return full_audio, vocal_audio, bgm_audio
-
-
-
-if __name__ == "__main__":
-    torch.backends.cudnn.enabled = False
-    OmegaConf.register_new_resolver("eval", lambda x: eval(x))
-    OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
-    OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
-    OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
-    np.random.seed(int(time.time()))
-    ckpt_path = sys.argv[1]
-    input_jsonl = sys.argv[2]
-    save_dir = sys.argv[3]
-    gen_type = sys.argv[4] if len(sys.argv) > 4 else "all"
-    cfg_path = os.path.join(ckpt_path, 'config.yaml')
-    ckpt_path = os.path.join(ckpt_path, 'model.pt')
-    cfg = OmegaConf.load(cfg_path)
-    cfg.mode = 'inference'
-    max_duration = cfg.max_dur
-
-    separator = Separator()
-    auto_prompt = torch.load('ckpt/prompt.pt')
-    audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
-    if "audio_tokenizer_checkpoint_sep" in cfg.keys():
-        seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
-    else:
-        seperate_tokenizer = None
-    audio_tokenizer = audio_tokenizer.eval().cuda()
-    if seperate_tokenizer is not None:
-        seperate_tokenizer = seperate_tokenizer.eval().cuda()
-
-    merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
-    with open(input_jsonl, "r") as fp:
-        lines = fp.readlines()
-    new_items = []
-    for line in lines:
-        item = json.loads(line)
-        target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
-        # get prompt audio
-        if "prompt_audio_path" in item:
-            assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
-            assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
-            pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
-            item['raw_pmt_wav'] = pmt_wav
-            item['raw_vocal_wav'] = vocal_wav
-            item['raw_bgm_wav'] = bgm_wav
-            if pmt_wav.dim() == 2:
-                pmt_wav = pmt_wav[None]
-            if pmt_wav.dim() != 3:
-                raise ValueError("Melody wavs should have a shape [B, C, T].")
-            pmt_wav = list(pmt_wav)
-            if vocal_wav.dim() == 2:
-                vocal_wav = vocal_wav[None]
-            if vocal_wav.dim() != 3:
-                raise ValueError("Vocal wavs should have a shape [B, C, T].")
-            vocal_wav = list(vocal_wav)
-            if bgm_wav.dim() == 2:
-                bgm_wav = bgm_wav[None]
-            if bgm_wav.dim() != 3:
-                raise ValueError("BGM wavs should have a shape [B, C, T].")
-            bgm_wav = list(bgm_wav)
-            if type(pmt_wav) == list:
-                pmt_wav = torch.stack(pmt_wav, dim=0)
-            if type(vocal_wav) == list:
-                vocal_wav = torch.stack(vocal_wav, dim=0)
-            if type(bgm_wav) == list:
-                bgm_wav = torch.stack(bgm_wav, dim=0)
-            pmt_wav = pmt_wav.cuda()
-            vocal_wav = vocal_wav.cuda()
-            bgm_wav = bgm_wav.cuda()
-            pmt_wav, _ = audio_tokenizer.encode(pmt_wav)
-            vocal_wav, bgm_wav = seperate_tokenizer.encode(vocal_wav, bgm_wav)
-            melody_is_wav = False
-        elif "auto_prompt_audio_type" in item:
-            assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
-            if item["auto_prompt_audio_type"] == "Auto":
-                prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
-            else:
-                prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
-            pmt_wav = prompt_token[:,[0],:]
-            vocal_wav = prompt_token[:,[1],:]
-            bgm_wav = prompt_token[:,[2],:]
-            melody_is_wav = False
-        else:
-            pmt_wav = None
-            vocal_wav = None
-            bgm_wav = None
-            melody_is_wav = True
-        item['pmt_wav'] = pmt_wav
-        item['vocal_wav'] = vocal_wav
-        item['bgm_wav'] = bgm_wav
-        item['melody_is_wav'] = melody_is_wav
-        item["idx"] = f"{item['idx']}"
-        item["wav_path"] = target_wav_name
-        new_items.append(item)
-
-    del audio_tokenizer
-    del seperate_tokenizer
-    del separator
-
-    # Define model or load pretrained model
-    model_light = CodecLM_PL(cfg, ckpt_path)
-    model_light = model_light.eval()
-    model_light.audiolm.cfg = cfg
-    model = CodecLM(name = "tmp",
-                    lm = model_light.audiolm,
-                    audiotokenizer = None,
-                    max_duration = max_duration,
-                    seperate_tokenizer = None,
-                    )
-    del model_light
-    model.lm = model.lm.cuda().to(torch.float16)
-
-    cfg_coef = 1.5 #25
-    temp = 0.9
-    top_k = 50
-    top_p = 0.0
-    record_tokens = True
-    record_window = 50
-
-    model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
-                                top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
-    os.makedirs(save_dir, exist_ok=True)
-    os.makedirs(save_dir + "/audios", exist_ok=True)
-    os.makedirs(save_dir + "/jsonl", exist_ok=True)
-
-
-    for item in new_items:
-        lyric = item["gt_lyric"]
-        descriptions = item["descriptions"] if "descriptions" in item else None
-        pmt_wav = item['pmt_wav']
-        vocal_wav = item['vocal_wav']
-        bgm_wav = item['bgm_wav']
-        melody_is_wav = item['melody_is_wav']
-
-        generate_inp = {
-            'lyrics': [lyric.replace("  ", " ")],
-            'descriptions': [descriptions],
-            'melody_wavs': pmt_wav,
-            'vocal_wavs': vocal_wav,
-            'bgm_wavs': bgm_wav,
-            'melody_is_wav': melody_is_wav,
-        }
-        with torch.autocast(device_type="cuda", dtype=torch.float16):
-            tokens = model.generate(**generate_inp, return_tokens=True)
-        item['tokens'] = tokens
-
-    del model
-    torch.cuda.empty_cache()
-
-
-    seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
-    seperate_tokenizer = seperate_tokenizer.eval().cuda()
-
-    model = CodecLM(name = "tmp",
-                    lm = None,
-                    audiotokenizer = None,
-                    max_duration = max_duration,
-                    seperate_tokenizer = seperate_tokenizer,
-                    )
-    for item in new_items:
-        with torch.no_grad():
-            if 'raw_pmt_wav' in item:
-                wav_seperate = model.generate_audio(item['tokens'], item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type=gen_type)
-                del item['raw_pmt_wav']
-                del item['raw_vocal_wav']
-                del item['raw_bgm_wav']
-            else:
-                wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type=gen_type)
-            torchaudio.save(item['wav_path'], wav_seperate[0].cpu().float(), cfg.sample_rate)
-        del item['tokens']
-        del item['pmt_wav']
-        del item['vocal_wav']
-        del item['bgm_wav']
-        del item['melody_is_wav']
-
-    torch.cuda.empty_cache()
-    src_jsonl_name = os.path.split(input_jsonl)[-1]
-    with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
-        for item in new_items:
-            fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
generate_lowmem.sh
DELETED
@@ -1,11 +0,0 @@
-export USER=root
-export PYTHONDONTWRITEBYTECODE=1
-export TRANSFORMERS_CACHE="$(pwd)/third_party/hub"
-export NCCL_HOME=/usr/local/tccl
-export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH
-
-CKPT_PATH=$1
-JSONL=$2
-SAVE_DIR=$3
-GEN_TYEP=$4
-python3 generate_lowmem.py $CKPT_PATH $JSONL $SAVE_DIR $GEN_TYEP
tools/gradio/app.py
CHANGED
@@ -49,7 +49,7 @@ with open(op.join(APP_DIR, 'conf/vocab.yaml'), 'r', encoding='utf-8') as file:
     STRUCTS = yaml.safe_load(file)


-def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_coef=None, temperature=None, top_k=None, gen_type="
     global MODEL
     global STRUCTS
     params = {'cfg_coef':cfg_coef, 'temperature':temperature, 'top_k':top_k}
@@ -240,4 +240,3 @@ lyrics
 # Launch the app
 if __name__ == "__main__":
     demo.launch(server_name="0.0.0.0", server_port=8081)
-

     STRUCTS = yaml.safe_load(file)


+def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_coef=None, temperature=None, top_k=None, gen_type="mixed", progress=gr.Progress(track_tqdm=True)):
     global MODEL
     global STRUCTS
     params = {'cfg_coef':cfg_coef, 'temperature':temperature, 'top_k':top_k}

 # Launch the app
 if __name__ == "__main__":
     demo.launch(server_name="0.0.0.0", server_port=8081)
tools/gradio/levo_inference.py
CHANGED
@@ -62,7 +62,7 @@ class LeVoInference(torch.nn.Module):

         self.model.set_generation_params(**self.default_params)

-    def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "
         params = {**self.default_params, **params}
         self.model.set_generation_params(**params)


         self.model.set_generation_params(**self.default_params)

+    def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "mixed", params = dict()):
         params = {**self.default_params, **params}
         self.model.set_generation_params(**params)
tools/gradio/levo_inference_lowmem.py
CHANGED
@@ -1,4 +1,5 @@
 import os
 import sys

 import torch
@@ -12,6 +13,7 @@ from codeclm.models import CodecLM
 from codeclm.models import builders

 from separator import Separator


 class LeVoInference(torch.nn.Module):
@@ -40,24 +42,28 @@ class LeVoInference(torch.nn.Module):
         )


-    def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "
         if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
             separator = Separator()
             audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
             audio_tokenizer = audio_tokenizer.eval().cuda()
-            seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
-            seperate_tokenizer = seperate_tokenizer.eval().cuda()
             pmt_wav, vocal_wav, bgm_wav = separator.run(prompt_audio_path)
             pmt_wav = pmt_wav.cuda()
             vocal_wav = vocal_wav.cuda()
             bgm_wav = bgm_wav.cuda()
-
-
-            melody_is_wav = False
-            melody_is_wav = False
             del audio_tokenizer
-            del seperate_tokenizer
             del separator
         elif genre is not None and auto_prompt_path is not None:
             auto_prompt = torch.load(auto_prompt_path)
             merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
@@ -75,17 +81,28 @@ class LeVoInference(torch.nn.Module):
             bgm_wav = None
             melody_is_wav = True

-
-
-
         model = CodecLM(name = "tmp",
-                        lm =
                         audiotokenizer = None,
                         max_duration = self.max_duration,
                         seperate_tokenizer = None,
                         )
-        del model_light
-        model.lm = model.lm.cuda().to(torch.float16)
         params = {**self.default_params, **params}
         model.set_generation_params(**params)

@@ -99,28 +116,53 @@ class LeVoInference(torch.nn.Module):
         }

         with torch.autocast(device_type="cuda", dtype=torch.float16):
-
-
        del model
        torch.cuda.empty_cache()

-        seperate_tokenizer = builders.
-
        model = CodecLM(name = "tmp",
                        lm = None,
                        audiotokenizer = None,
                        max_duration = self.max_duration,
                        seperate_tokenizer = seperate_tokenizer,
                        )
-
        with torch.no_grad():
            if melody_is_wav:
-               wav_seperate = model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav, gen_type=gen_type)
            else:
-               wav_seperate = model.generate_audio(tokens, gen_type=gen_type)

-
-
        torch.cuda.empty_cache()

        return wav_seperate[0]

 import os
+import gc
 import sys

 import torch

 from codeclm.models import builders

 from separator import Separator
+from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse


 class LeVoInference(torch.nn.Module):

         )


+    def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "mixed", params = dict()):
         if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
             separator = Separator()
             audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
             audio_tokenizer = audio_tokenizer.eval().cuda()
             pmt_wav, vocal_wav, bgm_wav = separator.run(prompt_audio_path)
             pmt_wav = pmt_wav.cuda()
             vocal_wav = vocal_wav.cuda()
             bgm_wav = bgm_wav.cuda()
+            with torch.no_grad():
+                pmt_wav, _ = audio_tokenizer.encode(pmt_wav)
             del audio_tokenizer
             del separator
+            torch.cuda.empty_cache()
+
+            seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
+            seperate_tokenizer = seperate_tokenizer.eval().cuda()
+            with torch.no_grad():
+                vocal_wav, bgm_wav = seperate_tokenizer.encode(vocal_wav, bgm_wav)
+            del seperate_tokenizer
+            melody_is_wav = False
+            torch.cuda.empty_cache()
         elif genre is not None and auto_prompt_path is not None:
             auto_prompt = torch.load(auto_prompt_path)
             merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]

             bgm_wav = None
             melody_is_wav = True

+        audiolm = builders.get_lm_model(self.cfg)
+        checkpoint = torch.load(self.pt_path, map_location='cpu')
+        audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
+        audiolm.load_state_dict(audiolm_state_dict, strict=False)
+        audiolm = audiolm.eval()
+
+        offload_audiolm = True if 'offload' in self.cfg.keys() and 'audiolm' in self.cfg.offload else False
+        if offload_audiolm:
+            audiolm_offload_param = OffloadParamParse.parse_config(audiolm, self.cfg.offload.audiolm)
+            audiolm_offload_param.show()
+            offload_profiler = OffloadProfiler(device_index=0, **(audiolm_offload_param.init_param_dict()))
+            offload_profiler.offload_layer(**(audiolm_offload_param.offload_layer_param_dict()))
+            offload_profiler.clean_cache_wrapper(**(audiolm_offload_param.clean_cache_param_dict()))
+        else:
+            audiolm = audiolm.cuda().to(torch.float16)
+
         model = CodecLM(name = "tmp",
+                        lm = audiolm,
                         audiotokenizer = None,
                         max_duration = self.max_duration,
                         seperate_tokenizer = None,
                         )
         params = {**self.default_params, **params}
         model.set_generation_params(**params)

         }

         with torch.autocast(device_type="cuda", dtype=torch.float16):
+            with torch.no_grad():
+                tokens = model.generate(**generate_inp, return_tokens=True)
+        if offload_audiolm:
+            offload_profiler.reset_empty_cache_mem_line()
+            offload_profiler.stop()
+            del offload_profiler
+            del audiolm_offload_param
         del model
+        audiolm = audiolm.cpu()
+        del audiolm
+        del checkpoint
+        gc.collect()
         torch.cuda.empty_cache()

+        seperate_tokenizer = builders.get_audio_tokenizer_model_cpu(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
+        device = "cuda:0"
+        seperate_tokenizer.model.device = device
+        seperate_tokenizer.model.vae = seperate_tokenizer.model.vae.to(device)
+        seperate_tokenizer.model.model.device = torch.device(device)
+        seperate_tokenizer = seperate_tokenizer.eval()
+
+        offload_wav_tokenizer_diffusion = True if 'offload' in self.cfg.keys() and 'wav_tokenizer_diffusion' in self.cfg.offload else False
+        if offload_wav_tokenizer_diffusion:
+            sep_offload_param = OffloadParamParse.parse_config(seperate_tokenizer, self.cfg.offload.wav_tokenizer_diffusion)
+            sep_offload_param.show()
+            sep_offload_profiler = OffloadProfiler(device_index=0, **(sep_offload_param.init_param_dict()))
+            sep_offload_profiler.offload_layer(**(sep_offload_param.offload_layer_param_dict()))
+            sep_offload_profiler.clean_cache_wrapper(**(sep_offload_param.clean_cache_param_dict()))
+        else:
+            seperate_tokenizer.model.model = seperate_tokenizer.model.model.to(device)
+
         model = CodecLM(name = "tmp",
                         lm = None,
                         audiotokenizer = None,
                         max_duration = self.max_duration,
                         seperate_tokenizer = seperate_tokenizer,
                         )
+
         with torch.no_grad():
             if melody_is_wav:
+                wav_seperate = model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav, gen_type=gen_type, chunked=True)
             else:
+                wav_seperate = model.generate_audio(tokens, gen_type=gen_type, chunked=True)

+        if offload_wav_tokenizer_diffusion:
+            sep_offload_profiler.reset_empty_cache_mem_line()
+            sep_offload_profiler.stop()
         torch.cuda.empty_cache()

         return wav_seperate[0]
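As a usage sketch, the low-memory backend above can be driven directly from Python. The constructor argument is an assumption (its signature is not part of this diff), and the lyric, paths and sample-rate attribute are placeholders; the keyword arguments to the call match the forward() signature shown above.

import torchaudio
from levo_inference_lowmem import LeVoInference

# Assumption: LeVoInference takes the checkpoint directory containing config.yaml and model.pt.
model = LeVoInference("ckpt/songgeneration_base")
wav = model(
    lyric="placeholder lyric text",     # real lyric formatting not shown in this diff
    prompt_audio_path=None,             # or a path to a reference audio clip
    genre="Pop",                        # used together with auto_prompt_path when no prompt audio is given
    auto_prompt_path="ckpt/prompt.pt",
    gen_type="mixed",                   # the new default
)
# Assumption: the sample rate lives on the loaded config, as in generate.py (cfg.sample_rate).
torchaudio.save("song.flac", wav.cpu().float(), model.cfg.sample_rate)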