Spaces: Running on Zero

Commit: c4fe16f · 1 Parent(s): 04be12f
sync from github

Files changed:
- .gitattributes +1 -0
- .gitignore +2 -1
- examples/cases.jsonl +5 -5
- indextts/cli.py +7 -4
- indextts/infer.py +73 -60
- indextts/infer_v2.py +144 -98
- indextts/s2mel/modules/openvoice/api.py +4 -4
- indextts/s2mel/modules/openvoice/openvoice_app.py +1 -1
- indextts/s2mel/modules/openvoice/utils.py +35 -35
- indextts/utils/front.py +53 -53
- indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +3 -0
- tools/i18n/locale/en_US.json +24 -21
- tools/i18n/locale/zh_CN.json +7 -3
- webui.py +149 -76

.gitattributes CHANGED
@@ -47,3 +47,4 @@ examples/emo_hate.wav filter=lfs diff=lfs merge=lfs -text
 examples/voice_01.wav filter=lfs diff=lfs merge=lfs -text
 examples/voice_03.wav filter=lfs diff=lfs merge=lfs -text
 examples/voice_04.wav filter=lfs diff=lfs merge=lfs -text
+indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 filter=lfs diff=lfs merge=lfs -text

.gitignore CHANGED
@@ -13,4 +13,5 @@ build/
 *.py[cod]
 *.egg-info/
 .venv
-checkpoints/*
+checkpoints/*
+__MACOSX

examples/cases.jsonl CHANGED
@@ -1,12 +1,12 @@
-{"prompt_audio":"voice_01.wav","text":"Translate for me…
+{"prompt_audio":"voice_01.wav","text":"Translate for me, what is a surprise!","emo_mode":0}
 {"prompt_audio":"voice_02.wav","text":"The palace is strict, no false rumors, Lady Qi!","emo_mode":0}
 {"prompt_audio":"voice_03.wav","text":"这个呀,就是我们精心制作准备的纪念品,大家可以看到这个色泽和这个材质啊,哎呀多么的光彩照人。","emo_mode":0}
 {"prompt_audio":"voice_04.wav","text":"你就需要我这种专业人士的帮助,就像手无缚鸡之力的人进入雪山狩猎,一定需要最老练的猎人指导。","emo_mode":0}
 {"prompt_audio":"voice_05.wav","text":"在真正的日本剑道中,格斗过程极其短暂,常常短至半秒,最长也不超过两秒,利剑相击的转瞬间,已有一方倒在血泊中。但在这电光石火的对决之前,双方都要以一个石雕般凝固的姿势站定,长时间的逼视对方,这一过程可能长达十分钟!","emo_mode":0}
 {"prompt_audio":"voice_06.wav","text":"今天呢,咱们开一部新书,叫《赛博朋克二零七七》。这词儿我听着都新鲜。这赛博朋克啊,简单理解就是“高科技,低生活”。这一听,我就明白了,于老师就爱用那高科技的东西,手机都得拿脚纹开,大冬天为了解锁脱得一丝不挂,冻得跟王八蛋似的。","emo_mode":0}
-{"prompt_audio":"voice_07.wav","emo_audio":"emo_sad.wav","emo_weight": 0…
+{"prompt_audio":"voice_07.wav","emo_audio":"emo_sad.wav","emo_weight": 1.0, "emo_mode":1,"text":"酒楼丧尽天良,开始借机竞拍房间,哎,一群蠢货。"}
-{"prompt_audio":"voice_08.wav","emo_audio":"emo_hate.wav","emo_weight": 0…
+{"prompt_audio":"voice_08.wav","emo_audio":"emo_hate.wav","emo_weight": 1.0, "emo_mode":1,"text":"你看看你,对我还有没有一点父子之间的信任了。"}
-{"prompt_audio":"voice_09.wav","emo_vec_3":0.…
+{"prompt_audio":"voice_09.wav","emo_vec_3":0.8,"emo_mode":2,"text":"对不起嘛!我的记性真的不太好,但是和你在一起的事情,我都会努力记住的~"}
-{"prompt_audio":"voice_10.wav","emo_vec_7":0…
+{"prompt_audio":"voice_10.wav","emo_vec_7":1.0,"emo_mode":2,"text":"哇塞!这个爆率也太高了!欧皇附体了!"}
 {"prompt_audio":"voice_11.wav","emo_mode":3,"emo_text":"极度悲伤","text":"这些年的时光终究是错付了... "}
 {"prompt_audio":"voice_12.wav","emo_mode":3,"emo_text":"You scared me to death! What are you, a ghost?","text":"快躲起来!是他要来了!他要来抓我们了!"}
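
As a quick sanity check of the updated cases file, the short sketch below (not part of the commit) loads examples/cases.jsonl and dispatches on emo_mode; the mapping 0 = no emotion control, 1 = emotion reference audio, 2 = emotion vector, 3 = emotion text is an assumption inferred from the fields each entry uses.

import json

with open("examples/cases.jsonl", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue
        case = json.loads(line)
        mode = case.get("emo_mode", 0)  # assumed meaning of emo_mode, see note above
        if mode == 1:
            print(case["prompt_audio"], "-> emotion audio:", case["emo_audio"], "weight:", case["emo_weight"])
        elif mode == 2:
            vec = {k: v for k, v in case.items() if k.startswith("emo_vec_")}
            print(case["prompt_audio"], "-> emotion vector:", vec)
        elif mode == 3:
            print(case["prompt_audio"], "-> emotion text:", case["emo_text"])
        else:
            print(case["prompt_audio"], "-> neutral")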

indextts/cli.py CHANGED
@@ -12,9 +12,9 @@ def main():
     parser.add_argument("-o", "--output_path", type=str, default="gen.wav", help="Path to the output wav file")
     parser.add_argument("-c", "--config", type=str, default="checkpoints/config.yaml", help="Path to the config file. Default is 'checkpoints/config.yaml'")
     parser.add_argument("--model_dir", type=str, default="checkpoints", help="Path to the model directory. Default is 'checkpoints'")
-    parser.add_argument("--fp16", action="store_true", default=…
+    parser.add_argument("--fp16", action="store_true", default=False, help="Use FP16 for inference if available")
     parser.add_argument("-f", "--force", action="store_true", default=False, help="Force to overwrite the output file if it exists")
-    parser.add_argument("-d", "--device", type=str, default=None, help="Device to run the model on (cpu, cuda, mps)." )
+    parser.add_argument("-d", "--device", type=str, default=None, help="Device to run the model on (cpu, cuda, mps, xpu)." )
     args = parser.parse_args()
     if len(args.text.strip()) == 0:
         print("ERROR: Text is empty.")
@@ -47,15 +47,18 @@ def main():
     if args.device is None:
         if torch.cuda.is_available():
             args.device = "cuda:0"
-        elif torch.…
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            args.device = "xpu"
+        elif hasattr(torch, "mps") and torch.mps.is_available():
             args.device = "mps"
         else:
            args.device = "cpu"
            args.fp16 = False  # Disable FP16 on CPU
            print("WARNING: Running on CPU may be slow.")

+    # TODO: Add CLI support for IndexTTS2.
     from indextts.infer import IndexTTS
-    tts = IndexTTS(cfg_path=args.config, model_dir=args.model_dir,…
+    tts = IndexTTS(cfg_path=args.config, model_dir=args.model_dir, use_fp16=args.fp16, device=args.device)
     tts.infer(audio_prompt=args.voice, text=args.text.strip(), output_path=output_path)

 if __name__ == "__main__":
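
For reference, the snippet below mirrors the call shape that cli.py now uses: IndexTTS is constructed with the new use_fp16 and device arguments, and device=None triggers the cuda/xpu/mps/cpu auto-selection shown above. It is a minimal sketch that assumes the default checkpoints directory and one of the bundled example voices.

from indextts.infer import IndexTTS

# device=None -> IndexTTS auto-selects cuda / xpu / mps / cpu, as in cli.py
tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints",
               use_fp16=False, device=None)
tts.infer(audio_prompt="examples/voice_01.wav",
          text="Translate for me, what is a surprise!",
          output_path="gen.wav")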

indextts/infer.py CHANGED
@@ -26,38 +26,42 @@ from indextts.utils.front import TextNormalizer, TextTokenizer
 
 class IndexTTS:
     def __init__(
-        self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints",…
+        self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=True, device=None,
         use_cuda_kernel=None,
     ):
         """
         Args:
             cfg_path (str): path to the config file.
             model_dir (str): path to the model directory.
-            …
+            use_fp16 (bool): whether to use fp16.
             device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
             use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
         """
         if device is not None:
             self.device = device
-            self.…
+            self.use_fp16 = False if device == "cpu" else use_fp16
             self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
         elif torch.cuda.is_available():
             self.device = "cuda:0"
-            self.…
+            self.use_fp16 = use_fp16
             self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            self.device = "xpu"
+            self.use_fp16 = use_fp16
+            self.use_cuda_kernel = False
         elif hasattr(torch, "mps") and torch.backends.mps.is_available():
             self.device = "mps"
-            self.…
+            self.use_fp16 = False  # Use float16 on MPS is overhead than float32
             self.use_cuda_kernel = False
         else:
             self.device = "cpu"
-            self.…
+            self.use_fp16 = False
             self.use_cuda_kernel = False
             print(">> Be patient, it may take a while to run in CPU mode.")
 
         self.cfg = OmegaConf.load(cfg_path)
         self.model_dir = model_dir
-        self.dtype = torch.float16 if self.…
+        self.dtype = torch.float16 if self.use_fp16 else None
         self.stop_mel_token = self.cfg.gpt.stop_mel_token
 
         # Comment-off to load the VQ-VAE model for debugging tokenizer
@@ -68,7 +72,7 @@ class IndexTTS:
         # self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint)
         # load_checkpoint(self.dvae, self.dvae_path)
         # self.dvae = self.dvae.to(self.device)
-        # if self.…
+        # if self.use_fp16:
         #     self.dvae.eval().half()
         # else:
         #     self.dvae.eval()
@@ -77,12 +81,12 @@ class IndexTTS:
         self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
         load_checkpoint(self.gpt, self.gpt_path)
         self.gpt = self.gpt.to(self.device)
-        if self.…
+        if self.use_fp16:
             self.gpt.eval().half()
         else:
             self.gpt.eval()
         print(">> GPT weights restored from:", self.gpt_path)
-        if self.…
+        if self.use_fp16:
             try:
                 import deepspeed
 
@@ -184,17 +188,17 @@ class IndexTTS:
         code_lens = torch.tensor(code_lens, dtype=torch.long, device=device)
         return codes, code_lens
 
-    def …
+    def bucket_segments(self, segments, bucket_max_size=4) -> List[List[Dict]]:
         """
-        …
-        if ``bucket_max_size=1``, return all…
+        Segment data bucketing.
+        if ``bucket_max_size=1``, return all segments in one bucket.
         """
         outputs: List[Dict] = []
-        for idx, sent in enumerate(…
+        for idx, sent in enumerate(segments):
             outputs.append({"idx": idx, "sent": sent, "len": len(sent)})
 
         if len(outputs) > bucket_max_size:
-            # split…
+            # split segments into buckets by segment length
             buckets: List[List[Dict]] = []
             factor = 1.5
             last_bucket = None
@@ -203,7 +207,7 @@ class IndexTTS:
             for sent in sorted(outputs, key=lambda x: x["len"]):
                 current_sent_len = sent["len"]
                 if current_sent_len == 0:
-                    print(">> skip empty…
+                    print(">> skip empty segment")
                     continue
                 if last_bucket is None \
                         or current_sent_len >= int(last_bucket_sent_len_median * factor) \
@@ -213,7 +217,7 @@ class IndexTTS:
                     last_bucket = buckets[-1]
                     last_bucket_sent_len_median = current_sent_len
                 else:
-                    # current bucket can hold more…
+                    # current bucket can hold more segments
                     last_bucket.append(sent)  # sorted
                     mid = len(last_bucket) // 2
                     last_bucket_sent_len_median = last_bucket[mid]["len"]
@@ -276,20 +280,20 @@ class IndexTTS:
             self.gr_progress(value, desc=desc)
 
     # 快速推理:对于“多句长文本”,可实现至少 2~10 倍以上的速度提升~ (First modified by sunnyboxs 2025-04-16)
-    def infer_fast(self, audio_prompt, text, output_path, verbose=False,…
-    …
+    def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_segment=100,
+                   segments_bucket_max_size=4, **generation_kwargs):
         """
         Args:
-            ``…
+            ``max_text_tokens_per_segment``: 分句的最大token数,默认``100``,可以根据GPU硬件情况调整
                 - 越小,batch 越多,推理速度越*快*,占用内存更多,可能影响质量
                 - 越大,batch 越少,推理速度越*慢*,占用内存和质量更接近于非快速推理
-            ``…
+            ``segments_bucket_max_size``: 分句分桶的最大容量,默认``4``,可以根据GPU内存调整
                 - 越大,bucket数量越少,batch越多,推理速度越*快*,占用内存更多,可能影响质量
                 - 越小,bucket数量越多,batch越少,推理速度越*慢*,占用内存和质量更接近于非快速推理
         """
-        print(">> …
+        print(">> starting fast inference...")
 
-        self._set_gr_progress(0, "…
+        self._set_gr_progress(0, "starting fast inference...")
         if verbose:
             print(f"origin text:{text}")
         start_time = time.perf_counter()
@@ -301,6 +305,15 @@ class IndexTTS:
         if audio.shape[0] > 1:
             audio = audio[0].unsqueeze(0)
         audio = torchaudio.transforms.Resample(sr, 24000)(audio)
+
+        max_audio_length_seconds = 50
+        max_audio_samples = int(max_audio_length_seconds * 24000)
+
+        if audio.shape[1] > max_audio_samples:
+            if verbose:
+                print(f"Audio too long ({audio.shape[1]} samples), truncating to {max_audio_samples} samples")
+            audio = audio[:, :max_audio_samples]
+
         cond_mel = MelSpectrogramFeatures()(audio).to(self.device)
         cond_mel_frame = cond_mel.shape[-1]
         if verbose:
@@ -319,13 +332,13 @@ class IndexTTS:
         # text_tokens
         text_tokens_list = self.tokenizer.tokenize(text)
 
-        …
-        …
+        segments = self.tokenizer.split_segments(text_tokens_list,
+                                                 max_text_tokens_per_segment=max_text_tokens_per_segment)
         if verbose:
             print(">> text token count:", len(text_tokens_list))
-            print("…
-            print("…
-            print(*…
+            print(" segments count:", len(segments))
+            print(" max_text_tokens_per_segment:", max_text_tokens_per_segment)
+            print(*segments, sep="\n")
         do_sample = generation_kwargs.pop("do_sample", True)
         top_p = generation_kwargs.pop("top_p", 0.8)
         top_k = generation_kwargs.pop("top_k", 30)
@@ -346,17 +359,17 @@ class IndexTTS:
         # text processing
         all_text_tokens: List[List[torch.Tensor]] = []
         self._set_gr_progress(0.1, "text processing...")
-        bucket_max_size = …
-        …
-        bucket_count = len(…
+        bucket_max_size = segments_bucket_max_size if self.device != "cpu" else 1
+        all_segments = self.bucket_segments(segments, bucket_max_size=bucket_max_size)
+        bucket_count = len(all_segments)
         if verbose:
-            print(">> …
-                  "bucket sizes:", [(len(s), [t["idx"] for t in s]) for s in …
+            print(">> segments bucket_count:", bucket_count,
+                  "bucket sizes:", [(len(s), [t["idx"] for t in s]) for s in all_segments],
                   "bucket_max_size:", bucket_max_size)
-        for …
+        for segments in all_segments:
             temp_tokens: List[torch.Tensor] = []
             all_text_tokens.append(temp_tokens)
-            for item in …
+            for item in segments:
                 sent = item["sent"]
                 text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
                 text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
@@ -365,11 +378,11 @@ class IndexTTS:
                     print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
                     # debug tokenizer
                     text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
-                    print("text_token_syms is same as …
+                    print("text_token_syms is same as segment tokens", text_token_syms == sent)
                 temp_tokens.append(text_tokens)
 
         # Sequential processing of bucketing data
-        all_batch_num = sum(len(s) for s in …
+        all_batch_num = sum(len(s) for s in all_segments)
         all_batch_codes = []
         processed_num = 0
         for item_tokens in all_text_tokens:
@@ -381,7 +394,7 @@ class IndexTTS:
             processed_num += batch_num
             # gpt speech
             self._set_gr_progress(0.2 + 0.3 * processed_num / all_batch_num,
-                                  f"gpt inference…
+                                  f"gpt speech inference {processed_num}/{all_batch_num}...")
             m_start_time = time.perf_counter()
             with torch.no_grad():
                 with torch.amp.autocast(batch_text_tokens.device.type, enabled=self.dtype is not None,
@@ -403,17 +416,17 @@ class IndexTTS:
         gpt_gen_time += time.perf_counter() - m_start_time
 
         # gpt latent
-        self._set_gr_progress(0.5, "gpt inference…
+        self._set_gr_progress(0.5, "gpt latents inference...")
         all_idxs = []
         all_latents = []
         has_warned = False
-        for batch_codes, batch_tokens, …
+        for batch_codes, batch_tokens, batch_segments in zip(all_batch_codes, all_text_tokens, all_segments):
             for i in range(batch_codes.shape[0]):
                 codes = batch_codes[i]  # [x]
                 if not has_warned and codes[-1] != self.stop_mel_token:
                     warnings.warn(
                         f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
-                        f"Consider reducing `…
+                        f"Consider reducing `max_text_tokens_per_segment`({max_text_tokens_per_segment}) or increasing `max_mel_tokens`.",
                         category=RuntimeWarning
                     )
                     has_warned = True
@@ -427,7 +440,7 @@ class IndexTTS:
                     print(codes)
                     print("code_lens:", code_lens)
                 text_tokens = batch_tokens[i]
-                all_idxs.append(…
+                all_idxs.append(batch_segments[i]["idx"])
                 m_start_time = time.perf_counter()
                 with torch.no_grad():
                     with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
@@ -440,7 +453,7 @@ class IndexTTS:
                             return_latent=True, clip_inputs=False)
                 gpt_forward_time += time.perf_counter() - m_start_time
                 all_latents.append(latent)
-        del all_batch_codes, all_text_tokens, …
+        del all_batch_codes, all_text_tokens, all_segments
         # bigvgan chunk
         chunk_size = 2
         all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))]
@@ -452,7 +465,7 @@ class IndexTTS:
         latent_length = len(all_latents)
 
         # bigvgan chunk decode
-        self._set_gr_progress(0.7, "bigvgan…
+        self._set_gr_progress(0.7, "bigvgan decoding...")
        tqdm_progress = tqdm(total=latent_length, desc="bigvgan")
        for items in chunk_latents:
            tqdm_progress.update(len(items))
@@ -474,7 +487,7 @@ class IndexTTS:
         self.torch_empty_cache()
 
         # wav audio output
-        self._set_gr_progress(0.9, "…
+        self._set_gr_progress(0.9, "saving audio...")
         wav = torch.cat(wavs, dim=1)
         wav_length = wav.shape[-1] / sampling_rate
         print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds")
@@ -503,10 +516,10 @@ class IndexTTS:
         return (sampling_rate, wav_data)
 
     # 原始推理模式
-    def infer(self, audio_prompt, text, output_path, verbose=False,…
+    def infer(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_segment=120,
              **generation_kwargs):
-        print(">> …
-        self._set_gr_progress(0, "…
+        print(">> starting inference...")
+        self._set_gr_progress(0, "starting inference...")
        if verbose:
            print(f"origin text:{text}")
        start_time = time.perf_counter()
@@ -533,12 +546,12 @@ class IndexTTS:
         self._set_gr_progress(0.1, "text processing...")
         auto_conditioning = cond_mel
         text_tokens_list = self.tokenizer.tokenize(text)
-        …
+        segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment)
         if verbose:
             print("text token count:", len(text_tokens_list))
-            print("…
-            print("…
-            print(*…
+            print("segments count:", len(segments))
+            print("max_text_tokens_per_segment:", max_text_tokens_per_segment)
+            print(*segments, sep="\n")
         do_sample = generation_kwargs.pop("do_sample", True)
         top_p = generation_kwargs.pop("top_p", 0.8)
         top_k = generation_kwargs.pop("top_k", 30)
@@ -557,7 +570,7 @@ class IndexTTS:
         bigvgan_time = 0
         progress = 0
         has_warned = False
-        for sent in …
+        for sent in segments:
             text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
             text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
             # text_tokens = F.pad(text_tokens, (0, 1))  # This may not be necessary.
@@ -568,13 +581,13 @@ class IndexTTS:
                 print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
                 # debug tokenizer
                 text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
-                print("text_token_syms is same as …
+                print("text_token_syms is same as segment tokens", text_token_syms == sent)
 
             # text_len = torch.IntTensor([text_tokens.size(1)], device=text_tokens.device)
             # print(text_len)
             progress += 1
-            self._set_gr_progress(0.2 + 0.4 * (progress - 1) / len(…
-                                  f"gpt inference…
+            self._set_gr_progress(0.2 + 0.4 * (progress - 1) / len(segments),
+                                  f"gpt latents inference {progress}/{len(segments)}...")
             m_start_time = time.perf_counter()
             with torch.no_grad():
                 with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
@@ -597,7 +610,7 @@ class IndexTTS:
                     warnings.warn(
                         f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
                         f"Input text tokens: {text_tokens.shape[1]}. "
-                        f"Consider reducing `…
+                        f"Consider reducing `max_text_tokens_per_segment`({max_text_tokens_per_segment}) or increasing `max_mel_tokens`.",
                         category=RuntimeWarning
                     )
                     has_warned = True
@@ -615,8 +628,8 @@ class IndexTTS:
                 print(codes, type(codes))
                 print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}")
                 print(f"code len: {code_lens}")
-            self._set_gr_progress(0.2 + 0.4 * progress / len(…
-                                  f"gpt inference…
+            self._set_gr_progress(0.2 + 0.4 * progress / len(segments),
+                                  f"gpt speech inference {progress}/{len(segments)}...")
             m_start_time = time.perf_counter()
             # latent, text_lens_out, code_lens_out = \
             with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
@@ -640,7 +653,7 @@ class IndexTTS:
             # wavs.append(wav[:, :-512])
             wavs.append(wav.cpu())  # to cpu before saving
         end_time = time.perf_counter()
-        self._set_gr_progress(0.9, "…
+        self._set_gr_progress(0.9, "saving audio...")
         wav = torch.cat(wavs, dim=1)
         wav_length = wav.shape[-1] / sampling_rate
         print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds")
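
The renamed parameters above (max_text_tokens_per_segment, segments_bucket_max_size) are ordinary keyword arguments on infer_fast. A minimal sketch of a call, assuming the default checkpoints directory and one of the bundled example voices:

from indextts.infer import IndexTTS

tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints")
tts.infer_fast(
    audio_prompt="examples/voice_02.wav",
    text="The palace is strict, no false rumors, Lady Qi!",
    output_path="gen_fast.wav",
    max_text_tokens_per_segment=100,  # smaller -> more, shorter segments; faster but may affect quality
    segments_bucket_max_size=4,       # larger -> fewer buckets, bigger batches; uses more memory
)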

indextts/infer_v2.py CHANGED
@@ -1,6 +1,9 @@
 import os
 from subprocess import CalledProcessError
 
+os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'
+import json
+import re
 import time
 import librosa
 import torch
@@ -34,38 +37,43 @@ import torch.nn.functional as F
 
 class IndexTTS2:
     def __init__(
-        self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints",…
+        self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, device=None,
         use_cuda_kernel=None,use_deepspeed=False
     ):
         """
         Args:
             cfg_path (str): path to the config file.
             model_dir (str): path to the model directory.
-            …
+            use_fp16 (bool): whether to use fp16.
             device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
             use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
+            use_deepspeed (bool): whether to use DeepSpeed or not.
         """
         if device is not None:
             self.device = device
-            self.…
+            self.use_fp16 = False if device == "cpu" else use_fp16
             self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
         elif torch.cuda.is_available():
             self.device = "cuda:0"
-            self.…
+            self.use_fp16 = use_fp16
             self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            self.device = "xpu"
+            self.use_fp16 = use_fp16
+            self.use_cuda_kernel = False
         elif hasattr(torch, "mps") and torch.backends.mps.is_available():
             self.device = "mps"
-            self.…
+            self.use_fp16 = False  # Use float16 on MPS is overhead than float32
             self.use_cuda_kernel = False
         else:
             self.device = "cpu"
-            self.…
+            self.use_fp16 = False
             self.use_cuda_kernel = False
             print(">> Be patient, it may take a while to run in CPU mode.")
 
         self.cfg = OmegaConf.load(cfg_path)
         self.model_dir = model_dir
-        self.dtype = torch.float16 if self.…
+        self.dtype = torch.float16 if self.use_fp16 else None
         self.stop_mel_token = self.cfg.gpt.stop_mel_token
 
         self.qwen_emo = QwenEmotion(os.path.join(self.model_dir, self.cfg.qwen_emo_path))
@@ -74,32 +82,30 @@ class IndexTTS2:
         self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
         load_checkpoint(self.gpt, self.gpt_path)
         self.gpt = self.gpt.to(self.device)
-        if self.…
+        if self.use_fp16:
             self.gpt.eval().half()
         else:
             self.gpt.eval()
         print(">> GPT weights restored from:", self.gpt_path)
-        …
+
+        if use_deepspeed:
             try:
                 import deepspeed
-                …
             except (ImportError, OSError, CalledProcessError) as e:
                 use_deepspeed = False
-                print(f">> DeepSpeed…
+                print(f">> Failed to load DeepSpeed. Falling back to normal inference. Error: {e}")
 
-        …
-        else:
-            self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=False)
+        self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=self.use_fp16)
 
         if self.use_cuda_kernel:
             # preload the CUDA kernel for BigVGAN
             try:
-                from indextts.…
+                from indextts.s2mel.modules.bigvgan.alias_free_activation.cuda import activation1d
 
-                …
-                …
-            except:
+                print(">> Preload custom CUDA kernel for BigVGAN", activation1d.anti_alias_activation_cuda)
+            except Exception as e:
                 print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.")
+                print(f"{e!r}")
                 self.use_cuda_kernel = False
 
         self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
@@ -143,7 +149,7 @@ class IndexTTS2:
         print(">> campplus_model weights restored from:", campplus_ckpt_path)
 
         bigvgan_name = self.cfg.vocoder.name
-        self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=…
+        self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=self.use_cuda_kernel)
         self.bigvgan = self.bigvgan.to(self.device)
         self.bigvgan.remove_weight_norm()
         self.bigvgan.eval()
@@ -261,7 +267,7 @@ class IndexTTS2:
 
     def insert_interval_silence(self, wavs, sampling_rate=22050, interval_silence=200):
         """
-        Insert silences between…
+        Insert silences between generated segments.
         wavs: List[torch.tensor]
         """
 
@@ -286,47 +292,69 @@ class IndexTTS2:
         if self.gr_progress is not None:
             self.gr_progress(value, desc=desc)
 
+    def _load_and_cut_audio(self,audio_path,max_audio_length_seconds,verbose=False,sr=None):
+        if not sr:
+            audio, sr = librosa.load(audio_path)
+        else:
+            audio, _ = librosa.load(audio_path,sr=sr)
+        audio = torch.tensor(audio).unsqueeze(0)
+        max_audio_samples = int(max_audio_length_seconds * sr)
+
+        if audio.shape[1] > max_audio_samples:
+            if verbose:
+                print(f"Audio too long ({audio.shape[1]} samples), truncating to {max_audio_samples} samples")
+            audio = audio[:, :max_audio_samples]
+        return audio, sr
+
     # 原始推理模式
     def infer(self, spk_audio_prompt, text, output_path,
               emo_audio_prompt=None, emo_alpha=1.0,
               emo_vector=None,
               use_emo_text=False, emo_text=None, use_random=False, interval_silence=200,
-              verbose=False,…
-        print(">> …
-        self._set_gr_progress(0, "…
+              verbose=False, max_text_tokens_per_segment=120, **generation_kwargs):
+        print(">> starting inference...")
+        self._set_gr_progress(0, "starting inference...")
         if verbose:
-            print(f"origin text:{text}, spk_audio_prompt:{spk_audio_prompt},"
-                  f"…
+            print(f"origin text:{text}, spk_audio_prompt:{spk_audio_prompt}, "
+                  f"emo_audio_prompt:{emo_audio_prompt}, emo_alpha:{emo_alpha}, "
                  f"emo_vector:{emo_vector}, use_emo_text:{use_emo_text}, "
                  f"emo_text:{emo_text}")
        start_time = time.perf_counter()

-        if use_emo_text:
+        if use_emo_text or emo_vector is not None:
+            # we're using a text or emotion vector guidance; so we must remove
+            # "emotion reference voice", to ensure we use correct emotion mixing!
             emo_audio_prompt = None
-            …
-            …
-            # …
+
+        if use_emo_text:
+            # automatically generate emotion vectors from text prompt
             if emo_text is None:
-                emo_text = text
-            emo_dict…
-            print(emo_dict)
+                emo_text = text  # use main text prompt
+            emo_dict = self.qwen_emo.inference(emo_text)
+            print(f"detected emotion vectors from text: {emo_dict}")
+            # convert ordered dict to list of vectors; the order is VERY important!
             emo_vector = list(emo_dict.values())
 
         if emo_vector is not None:
-            …
-            …
-            # …
-            …
+            # we have emotion vectors; they can't be blended via alpha mixing
+            # in the main inference process later, so we must pre-calculate
+            # their new strengths here based on the alpha instead!
+            emo_vector_scale = max(0.0, min(1.0, emo_alpha))
+            if emo_vector_scale != 1.0:
+                # scale each vector and truncate to 4 decimals (for nicer printing)
+                emo_vector = [int(x * emo_vector_scale * 10000) / 10000 for x in emo_vector]
+                print(f"scaled emotion vectors to {emo_vector_scale}x: {emo_vector}")
 
         if emo_audio_prompt is None:
+            # we are not using any external "emotion reference voice"; use
+            # speaker's voice as the main emotion reference audio.
             emo_audio_prompt = spk_audio_prompt
+            # must always use alpha=1.0 when we don't have an external reference voice
             emo_alpha = 1.0
-        # assert emo_alpha == 1.0
 
         # 如果参考音频改变了,才需要重新生成, 提升速度
         if self.cache_spk_cond is None or self.cache_spk_audio_prompt != spk_audio_prompt:
-            audio,…
-            audio = torch.tensor(audio).unsqueeze(0)
+            audio,sr = self._load_and_cut_audio(spk_audio_prompt,15,verbose)
             audio_22k = torchaudio.transforms.Resample(sr, 22050)(audio)
             audio_16k = torchaudio.transforms.Resample(sr, 16000)(audio)
 
@@ -377,7 +405,7 @@ class IndexTTS2:
             emovec_mat = emovec_mat.unsqueeze(0)
 
         if self.cache_emo_cond is None or self.cache_emo_audio_prompt != emo_audio_prompt:
-            emo_audio, _ = …
+            emo_audio, _ = self._load_and_cut_audio(emo_audio_prompt,15,verbose,sr=16000)
             emo_inputs = self.extract_features(emo_audio, sampling_rate=16000, return_tensors="pt")
             emo_input_features = emo_inputs["input_features"]
             emo_attention_mask = emo_inputs["attention_mask"]
@@ -392,12 +420,13 @@ class IndexTTS2:
 
         self._set_gr_progress(0.1, "text processing...")
         text_tokens_list = self.tokenizer.tokenize(text)
-        …
+        segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment)
+        segments_count = len(segments)
         if verbose:
             print("text_tokens_list:", text_tokens_list)
-            print("…
-            print("…
-            print(*…
+            print("segments count:", segments_count)
+            print("max_text_tokens_per_segment:", max_text_tokens_per_segment)
+            print(*segments, sep="\n")
         do_sample = generation_kwargs.pop("do_sample", True)
         top_p = generation_kwargs.pop("top_p", 0.8)
         top_k = generation_kwargs.pop("top_k", 30)
@@ -414,9 +443,11 @@ class IndexTTS2:
         gpt_forward_time = 0
         s2mel_time = 0
         bigvgan_time = 0
-        progress = 0
         has_warned = False
-        for sent in …
+        for seg_idx, sent in enumerate(segments):
+            self._set_gr_progress(0.2 + 0.7 * seg_idx / segments_count,
+                                  f"speech synthesis {seg_idx + 1}/{segments_count}...")
+
             text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
             text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
             if verbose:
@@ -424,7 +455,7 @@ class IndexTTS2:
                 print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
                 # debug tokenizer
                 text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
-                print("text_token_syms is same as …
+                print("text_token_syms is same as segment tokens", text_token_syms == sent)
 
             m_start_time = time.perf_counter()
             with torch.no_grad():
@@ -465,7 +496,7 @@ class IndexTTS2:
                     warnings.warn(
                         f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
                         f"Input text tokens: {text_tokens.shape[1]}. "
-                        f"Consider reducing `…
+                        f"Consider reducing `max_text_tokens_per_segment`({max_text_tokens_per_segment}) or increasing `max_mel_tokens`.",
                         category=RuntimeWarning
                     )
                     has_warned = True
@@ -546,7 +577,8 @@ class IndexTTS2:
             # wavs.append(wav[:, :-512])
             wavs.append(wav.cpu())  # to cpu before saving
         end_time = time.perf_counter()
-        …
+
+        self._set_gr_progress(0.9, "saving audio...")
         wavs = self.insert_interval_silence(wavs, sampling_rate=sampling_rate, interval_silence=interval_silence)
         wav = torch.cat(wavs, dim=1)
         wav_length = wav.shape[-1] / sampling_rate
@@ -595,59 +627,52 @@ class QwenEmotion:
             device_map="auto"
         )
         self.prompt = "文本情感分类"
-        self.…
-            "愤怒": "angry",
+        self.cn_key_to_en = {
             "高兴": "happy",
-            "…
-            "反感": "hate",
+            "愤怒": "angry",
             "悲伤": "sad",
-            "…
-            "…
-            …
+            "恐惧": "afraid",
+            "反感": "disgusted",
+            # TODO: the "低落" (melancholic) emotion will always be mapped to
+            # "悲伤" (sad) by QwenEmotion's text analysis. it doesn't know the
+            # difference between those emotions even if user writes exact words.
+            # SEE: `self.melancholic_words` for current workaround.
+            "低落": "melancholic",
+            "惊讶": "surprised",
+            "自然": "calm",
+        }
+        self.desired_vector_order = ["高兴", "愤怒", "悲伤", "恐惧", "反感", "低落", "惊讶", "自然"]
+        self.melancholic_words = {
+            # emotion text phrases that will force QwenEmotion's "悲伤" (sad) detection
+            # to become "低落" (melancholic) instead, to fix limitations mentioned above.
+            "低落",
+            "melancholy",
+            "melancholic",
+            "depression",
+            "depressed",
+            "gloomy",
         }
-        self.backup_dict = {"happy": 0, "angry": 0, "sad": 0, "fear": 0, "hate": 0, "low": 0, "surprise": 0,
-                            "neutral": 1.0}
         self.max_score = 1.2
         self.min_score = 0.0
 
+    def clamp_score(self, value):
+        return max(self.min_score, min(self.max_score, value))
+
     def convert(self, content):
-        …
-        ordered_parts = [parts_dict[key] for key in desired_order if key in parts_dict]
-        parts = ordered_parts
-        if len(parts) != len(self.convert_dict):
-            return self.backup_dict
-
-        emotion_dict = {}
-        for part in parts:
-            key_value = part.strip().split(':')
-            if len(key_value) == 2:
-                try:
-                    key = self.convert_dict[key_value[0].strip()]
-                    value = float(key_value[1].strip())
-                    value = max(self.min_score, min(self.max_score, value))
-                    emotion_dict[key] = value
-                except Exception:
-                    continue
-
-        for key in self.backup_dict:
-            if key not in emotion_dict:
-                emotion_dict[key] = 0.0
-
-        if sum(emotion_dict.values()) <= 0:
-            return self.backup_dict
+        # generate emotion vector dictionary:
+        # - insert values in desired order (Python 3.7+ `dict` remembers insertion order)
+        # - convert Chinese keys to English
+        # - clamp all values to the allowed min/max range
+        # - use 0.0 for any values that were missing in `content`
+        emotion_dict = {
+            self.cn_key_to_en[cn_key]: self.clamp_score(content.get(cn_key, 0.0))
+            for cn_key in self.desired_vector_order
+        }
+
+        # default to a calm/neutral voice if all emotion vectors were empty
+        if all(val <= 0.0 for val in emotion_dict.values()):
+            print(">> no emotions detected; using default calm/neutral voice")
+            emotion_dict["calm"] = 1.0
 
         return emotion_dict
 
@@ -680,9 +705,30 @@ class QwenEmotion:
         except ValueError:
             index = 0
 
-        content = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True)
-        …
-        …
+        content = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True)
+
+        # decode the JSON emotion detections as a dictionary
+        try:
+            content = json.loads(content)
+        except json.decoder.JSONDecodeError:
+            # invalid JSON; fallback to manual string parsing
+            # print(">> parsing QwenEmotion response", content)
+            content = {
+                m.group(1): float(m.group(2))
+                for m in re.finditer(r'([^\s":.,]+?)"?\s*:\s*([\d.]+)', content)
+            }
+            # print(">> dict result", content)
+
+        # workaround for QwenEmotion's inability to distinguish "悲伤" (sad) vs "低落" (melancholic).
+        # if we detect any of the IndexTTS "melancholic" words, we swap those vectors
+        # to encode the "sad" emotion as "melancholic" (instead of sadness).
+        text_input_lower = text_input.lower()
+        if any(word in text_input_lower for word in self.melancholic_words):
+            # print(">> before vec swap", content)
+            content["悲伤"], content["低落"] = content.get("低落", 0.0), content.get("悲伤", 0.0)
+            # print(">> after vec swap", content)
+
+        return self.convert(content)
 
 
 if __name__ == "__main__":
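
To illustrate the emotion handling that changes in this diff, the sketch below calls IndexTTS2.infer once with text-derived emotion (use_emo_text, which now also clears any emotion reference audio) and once with an explicit vector scaled by emo_alpha. It assumes the default checkpoints directory; the eight-element vector order (高兴, 愤怒, 悲伤, 恐惧, 反感, 低落, 惊讶, 自然) is taken from desired_vector_order above.

from indextts.infer_v2 import IndexTTS2

tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False)

# Emotion inferred from text by QwenEmotion; the emotion vector is built internally.
tts.infer("examples/voice_11.wav", "这些年的时光终究是错付了... ", "gen_sad.wav",
          use_emo_text=True, emo_text="极度悲伤")

# Explicit emotion vector ("surprised" slot set); emo_alpha < 1.0 pre-scales the vector.
tts.infer("examples/voice_10.wav", "哇塞!这个爆率也太高了!欧皇附体了!", "gen_surprised.wav",
          emo_vector=[0, 0, 0, 0, 0, 0, 1.0, 0], emo_alpha=0.8,
          max_text_tokens_per_segment=120)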

indextts/s2mel/modules/openvoice/api.py CHANGED
@@ -63,9 +63,9 @@ class BaseSpeakerTTS(OpenVoiceBaseClass):
         return audio_segments
 
     @staticmethod
-    def …
-        texts = utils.…
-        print(" > Text…
+    def split_segments_into_pieces(text, language_str):
+        texts = utils.split_segment(text, language_str=language_str)
+        print(" > Text split into segments.")
         print('\n'.join(texts))
         print(" > ===========================")
         return texts
@@ -74,7 +74,7 @@ class BaseSpeakerTTS(OpenVoiceBaseClass):
         mark = self.language_marks.get(language.lower(), None)
         assert mark is not None, f"language {language} is not supported"
 
-        texts = self.…
+        texts = self.split_segments_into_pieces(text, mark)
 
         audio_list = []
         for t in texts:
indextts/s2mel/modules/openvoice/openvoice_app.py
CHANGED
@@ -233,7 +233,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
         with gr.Column():
             input_text_gr = gr.Textbox(
                 label="Text Prompt",
+                info="One or two sentences at a time produces the best results. Up to 200 text characters.",
                 value="He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.",
             )
             style_gr = gr.Dropdown(
indextts/s2mel/modules/openvoice/utils.py
CHANGED
@@ -75,23 +75,23 @@ def bits_to_string(bits_array):
     return output_string
 
 
+def split_segment(text, min_len=10, language_str='[EN]'):
     if language_str in ['EN']:
+        segments = split_segments_latin(text, min_len=min_len)
     else:
+        segments = split_segments_zh(text, min_len=min_len)
+    return segments
 
+def split_segments_latin(text, min_len=10):
+    """Split Long sentences into list of short segments.
 
    Args:
        str: Input sentences.
 
    Returns:
+        List[str]: list of output segments.
    """
+    # deal with dirty text characters
    text = re.sub('[。!?;]', '.', text)
    text = re.sub('[,]', ',', text)
    text = re.sub('[“”]', '"', text)
@@ -100,36 +100,36 @@ def split_sentences_latin(text, min_len=10):
    text = re.sub('[\n\t ]+', ' ', text)
    text = re.sub('([,.!?;])', r'\1 $#!', text)
    # split
+    segments = [s.strip() for s in text.split('$#!')]
+    if len(segments[-1]) == 0: del segments[-1]
 
+    new_segments = []
    new_sent = []
    count_len = 0
+    for ind, sent in enumerate(segments):
        # print(sent)
        new_sent.append(sent)
        count_len += len(sent.split(" "))
+        if count_len > min_len or ind == len(segments) - 1:
            count_len = 0
+            new_segments.append(' '.join(new_sent))
            new_sent = []
+    return merge_short_segments_latin(new_segments)
 
 
+def merge_short_segments_latin(sens):
+    """Avoid short segments by merging them with the following segment.
 
    Args:
+        List[str]: list of input segments.
 
    Returns:
+        List[str]: list of output segments.
    """
    sens_out = []
    for s in sens:
+        # If the previous segment is too short, merge them with
+        # the current segment.
        if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
            sens_out[-1] = sens_out[-1] + " " + s
        else:
@@ -142,7 +142,7 @@ def merge_short_sentences_latin(sens):
            pass
    return sens_out
 
+def split_segments_zh(text, min_len=10):
    text = re.sub('[。!?;]', '.', text)
    text = re.sub('[,]', ',', text)
    # 将文本中的换行符、空格和制表符替换为空格
@@ -150,37 +150,37 @@ def split_sentences_zh(text, min_len=10):
    # 在标点符号后添加一个空格
    text = re.sub('([,.!?;])', r'\1 $#!', text)
    # 分隔句子并去除前后空格
+    # segments = [s.strip() for s in re.split('(。|!|?|;)', text)]
+    segments = [s.strip() for s in text.split('$#!')]
+    if len(segments[-1]) == 0: del segments[-1]
 
+    new_segments = []
    new_sent = []
    count_len = 0
+    for ind, sent in enumerate(segments):
        new_sent.append(sent)
        count_len += len(sent)
+        if count_len > min_len or ind == len(segments) - 1:
            count_len = 0
+            new_segments.append(' '.join(new_sent))
            new_sent = []
+    return merge_short_segments_zh(new_segments)
 
 
+def merge_short_segments_zh(sens):
    # return sens
+    """Avoid short segments by merging them with the following segment.
 
    Args:
+        List[str]: list of input segments.
 
    Returns:
+        List[str]: list of output segments.
    """
    sens_out = []
    for s in sens:
        # If the previous sentense is too short, merge them with
+        # the current segment.
        if len(sens_out) > 0 and len(sens_out[-1]) <= 2:
            sens_out[-1] = sens_out[-1] + " " + s
        else:
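For illustration, a simplified, self-contained sketch of the segmentation rule used by split_segments_latin above (not an import of the module, and it omits the merge_short_segments_latin pass): punctuation is tagged with a '$#!' sentinel, the text is split there, and pieces are joined until they exceed `min_len` words.

    import re

    def split_latin_sketch(text, min_len=10):
        text = re.sub('([,.!?;])', r'\1 $#!', re.sub('[\n\t ]+', ' ', text))
        pieces = [s.strip() for s in text.split('$#!') if s.strip()]
        out, cur, count = [], [], 0
        for i, piece in enumerate(pieces):
            cur.append(piece)
            count += len(piece.split(" "))
            if count > min_len or i == len(pieces) - 1:
                out.append(' '.join(cur))
                cur, count = [], 0
        return out

    print(split_latin_sketch("Hi there. This is a slightly longer sentence, used for the demo. Bye."))
    # ['Hi there. This is a slightly longer sentence, used for the demo.', 'Bye.']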
indextts/utils/front.py
CHANGED
@@ -91,7 +91,7 @@ class TextNormalizer:
         import platform
         if self.zh_normalizer is not None and self.en_normalizer is not None:
             return
+        if platform.system() != "Linux":  # Mac and Windows
             from wetext import Normalizer
 
             self.zh_normalizer = Normalizer(remove_erhua=False, lang="zh", operator="tn")
@@ -342,8 +342,8 @@ class TextTokenizer:
         return de_tokenized_by_CJK_char(decoded, do_lower_case=do_lower_case)
 
     @staticmethod
+    def split_segments_by_token(
+        tokenized_str: List[str], split_tokens: List[str], max_text_tokens_per_segment: int
     ) -> List[List[str]]:
         """
         将tokenize后的结果按特定token进一步分割
@@ -351,67 +351,67 @@ class TextTokenizer:
         # 处理特殊情况
         if len(tokenized_str) == 0:
             return []
+        segments: List[List[str]] = []
+        current_segment = []
+        current_segment_tokens_len = 0
         for i in range(len(tokenized_str)):
             token = tokenized_str[i]
+            current_segment.append(token)
+            current_segment_tokens_len += 1
+            if current_segment_tokens_len <= max_text_tokens_per_segment:
+                if token in split_tokens and current_segment_tokens_len > 2:
                     if i < len(tokenized_str) - 1:
                         if tokenized_str[i + 1] in ["'", "▁'"]:
                             # 后续token是',则不切分
+                            current_segment.append(tokenized_str[i + 1])
                             i += 1
+                    segments.append(current_segment)
+                    current_segment = []
+                    current_segment_tokens_len = 0
                 continue
             # 如果当前tokens的长度超过最大限制
+            if not ("," in split_tokens or "▁," in split_tokens) and ("," in current_segment or "▁," in current_segment):
                 # 如果当前tokens中有,,则按,分割
+                sub_segments = TextTokenizer.split_segments_by_token(
+                    current_segment, [",", "▁,"], max_text_tokens_per_segment=max_text_tokens_per_segment
                 )
+            elif "-" not in split_tokens and "-" in current_segment:
                 # 没有,,则按-分割
+                sub_segments = TextTokenizer.split_segments_by_token(
+                    current_segment, ["-"], max_text_tokens_per_segment=max_text_tokens_per_segment
                 )
             else:
                 # 按照长度分割
+                sub_segments = []
+                for j in range(0, len(current_segment), max_text_tokens_per_segment):
+                    if j + max_text_tokens_per_segment < len(current_segment):
+                        sub_segments.append(current_segment[j : j + max_text_tokens_per_segment])
                     else:
+                        sub_segments.append(current_segment[j:])
                 warnings.warn(
+                    f"The tokens length of segment exceeds limit: {max_text_tokens_per_segment}, "
+                    f"Tokens in segment: {current_segment}."
                     "Maybe unexpected behavior",
                     RuntimeWarning,
                 )
+            segments.extend(sub_segments)
+            current_segment = []
+            current_segment_tokens_len = 0
+        if current_segment_tokens_len > 0:
+            assert current_segment_tokens_len <= max_text_tokens_per_segment
+            segments.append(current_segment)
         # 如果相邻的句子加起来长度小于最大限制,则合并
+        merged_segments = []
+        for segment in segments:
+            if len(segment) == 0:
                 continue
+            if len(merged_segments) == 0:
+                merged_segments.append(segment)
+            elif len(merged_segments[-1]) + len(segment) <= max_text_tokens_per_segment:
+                merged_segments[-1] = merged_segments[-1] + segment
             else:
+                merged_segments.append(segment)
+        return merged_segments
 
     punctuation_marks_tokens = [
         ".",
@@ -422,9 +422,9 @@ class TextTokenizer:
         "▁?",
         "▁...",  # ellipsis
     ]
+    def split_segments(self, tokenized: List[str], max_text_tokens_per_segment=120) -> List[List[str]]:
+        return TextTokenizer.split_segments_by_token(
+            tokenized, self.punctuation_marks_tokens, max_text_tokens_per_segment=max_text_tokens_per_segment
         )
 
 
@@ -516,19 +516,19 @@ if __name__ == "__main__":
         # 测试 normalize后的字符能被分词器识别
         print(f"`{ch}`", "->", tokenizer.sp_model.Encode(ch, out_type=str))
         print(f"` {ch}`", "->", tokenizer.sp_model.Encode(f" {ch}", out_type=str))
+    max_text_tokens_per_segment = 120
     for i in range(len(cases)):
         print(f"原始文本: {cases[i]}")
         print(f"Normalized: {text_normalizer.normalize(cases[i])}")
         tokens = tokenizer.tokenize(cases[i])
         print("Tokenzied: ", ", ".join([f"`{t}`" for t in tokens]))
+        segments = tokenizer.split_segments(tokens, max_text_tokens_per_segment=max_text_tokens_per_segment)
+        print("Segments count:", len(segments))
+        if len(segments) > 1:
+            for j in range(len(segments)):
+                print(f"  {j}, count:", len(segments[j]), ", tokens:", "".join(segments[j]))
+                if len(segments[j]) > max_text_tokens_per_segment:
+                    print(f"Warning: segment {j} is too long, length: {len(segments[j])}")
         #print(f"Token IDs (first 10): {codes[i][:10]}")
         if tokenizer.unk_token in codes[i]:
             print(f"Warning: `{cases[i]}` contains UNKNOWN token")
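A rough, self-contained sketch of the splitting policy that split_segments_by_token implements (simplified: it drops the comma/hyphen fallback splits and the final merge step, so it is not the repository code): cut after punctuation tokens, never letting a segment exceed `max_tokens`.

    def split_sketch(tokens, split_tokens, max_tokens):
        segments, cur = [], []
        for tok in tokens:
            cur.append(tok)
            if (tok in split_tokens and len(cur) > 2) or len(cur) >= max_tokens:
                segments.append(cur)
                cur = []
        if cur:
            segments.append(cur)
        return segments

    tokens = ["▁hello", "▁world", ".", "▁this", "▁is", "▁a", "▁test", "!"]
    print(split_sketch(tokens, split_tokens=[".", "!"], max_tokens=120))
    # [['▁hello', '▁world', '.'], ['▁this', '▁is', '▁a', '▁test', '!']]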
indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169
+size 21029926
tools/i18n/locale/en_US.json
CHANGED
@@ -1,46 +1,49 @@
 {
+    "本软件以自拟协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.": "This software is open-sourced under customized license. The author has no control over the software, and users of the software, as well as those who distribute the audio generated by the software, assume full responsibility.",
+    "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "If you do not agree to these terms, you are not permitted to use or reference any code or files within the software package. For further details, please refer to the LICENSE files in the root directory.",
     "时长必须为正数": "Duration must be a positive number",
     "请输入有效的浮点数": "Please enter a valid floating-point number",
     "使用情感参考音频": "Use emotion reference audio",
+    "使用情感向量控制": "Use emotion vectors",
     "使用情感描述文本控制": "Use text description to control emotion",
     "上传情感参考音频": "Upload emotion reference audio",
     "情感权重": "Emotion control weight",
     "喜": "Happy",
     "怒": "Angry",
     "哀": "Sad",
+    "惧": "Afraid",
+    "厌恶": "Disgusted",
+    "低落": "Melancholic",
+    "惊喜": "Surprised",
+    "平静": "Calm",
     "情感描述文本": "Emotion description",
+    "请输入情绪描述(或留空以自动使用目标文本作为情绪描述)": "Please input an emotion description (or leave blank to automatically use the main text prompt)",
     "高级生成参数设置": "Advanced generation parameter settings",
     "情感向量之和不能超过1.5,请调整后重试。": "The sum of the emotion vectors cannot exceed 1.5. Please adjust and try again.",
+    "音色参考音频": "Voice Reference",
     "音频生成": "Speech Synthesis",
     "文本": "Text",
     "生成语音": "Synthesize",
     "生成结果": "Synthesis Result",
     "功能设置": "Settings",
+    "分句设置": "Text segmentation settings",
+    "参数会影响音频质量和生成速度": "These parameters affect the audio quality and generation speed.",
+    "分句最大Token数": "Max tokens per generation segment",
+    "建议80~200之间,值越大,分句越长;值越小,分句越碎;过小过大都可能导致音频质量不高": "Recommended range: 80 - 200. Larger values require more VRAM but improves the flow of the speech, while lower values require less VRAM but means more fragmented sentences. Values that are too small or too large may lead to less coherent speech.",
+    "预览分句结果": "Preview of the audio generation segments",
     "序号": "Index",
     "分句内容": "Content",
     "Token数": "Token Count",
     "情感控制方式": "Emotion control method",
     "GPT2 采样设置": "GPT-2 Sampling Configuration",
+    "参数会影响音频多样性和生成速度详见": "Influences both the diversity of the generated audio and the generation speed. For further details, refer to",
+    "是否进行采样": "Enable GPT-2 sampling",
+    "生成Token最大数量,过小导致音频被截断": "Maximum number of tokens to generate. If text exceeds this, the audio will be cut off.",
+    "请上传情感参考音频": "Please upload the emotion reference audio",
+    "当前模型版本": "Current model version: ",
+    "请输入目标文本": "Please input the text to synthesize",
+    "例如:委屈巴巴、危险在悄悄逼近": "e.g. deeply sad, danger is creeping closer",
     "与音色参考音频相同": "Same as the voice reference",
+    "情感随机采样": "Randomize emotion sampling",
+    "显示实验功能": "Show experimental features"
 }
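The locale files are flat key-to-value maps keyed by the original Chinese UI strings. As a rough sketch of how such a table can be consumed (an assumption about the lookup behaviour, not the actual tools/i18n/i18n.py implementation), a missing entry simply falls back to the untranslated key:

    import json

    def make_i18n(locale_path):
        # load the flat locale table, e.g. tools/i18n/locale/en_US.json
        with open(locale_path, "r", encoding="utf-8") as f:
            table = json.load(f)
        return lambda key: table.get(key, key)  # fall back to the untranslated key

    # i18n = make_i18n("tools/i18n/locale/en_US.json")
    # i18n("生成语音")  # -> "Synthesize"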
tools/i18n/locale/zh_CN.json
CHANGED
@@ -1,5 +1,5 @@
 {
+    "本软件以自拟协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.": "本软件以自拟协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.",
     "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.",
     "时长必须为正数": "时长必须为正数",
     "请输入有效的浮点数": "请输入有效的浮点数",
@@ -17,7 +17,7 @@
     "惊喜": "惊喜",
     "平静": "平静",
     "情感描述文本": "情感描述文本",
+    "请输入情绪描述(或留空以自动使用目标文本作为情绪描述)": "请输入情绪描述(或留空以自动使用目标文本作为情绪描述)",
     "高级生成参数设置": "高级生成参数设置",
     "情感向量之和不能超过1.5,请调整后重试。": "情感向量之和不能超过1.5,请调整后重试。",
     "音色参考音频": "音色参考音频",
@@ -36,5 +36,9 @@
     "Token数": "Token数",
     "情感控制方式": "情感控制方式",
     "GPT2 采样设置": "GPT2 采样设置",
-    "参数会影响音频多样性和生成速度详见": "参数会影响音频多样性和生成速度详见"
+    "参数会影响音频多样性和生成速度详见": "参数会影响音频多样性和生成速度详见",
+    "是否进行采样": "是否进行采样",
+    "生成Token最大数量,过小导致音频被截断": "生成Token最大数量,过小导致音频被截断",
+    "显示实验功能": "显示实验功能",
+    "例如:委屈巴巴、危险在悄悄逼近": "例如:委屈巴巴、危险在悄悄逼近"
 }
webui.py
CHANGED
@@ -1,6 +1,4 @@
 import json
-import logging
-import spaces
 import os
 import sys
 import threading
@@ -8,40 +6,60 @@ import time
 
 import warnings
 
+import numpy as np
 
 warnings.filterwarnings("ignore", category=FutureWarning)
 warnings.filterwarnings("ignore", category=UserWarning)
 
+import pandas as pd
+
 current_dir = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(current_dir)
 sys.path.append(os.path.join(current_dir, "indextts"))
 
 import argparse
+parser = argparse.ArgumentParser(
+    description="IndexTTS WebUI",
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+)
 parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode")
 parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
 parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the web UI on")
-parser.add_argument("--model_dir", type=str, default="checkpoints", help="Model checkpoints directory")
+parser.add_argument("--model_dir", type=str, default="./checkpoints", help="Model checkpoints directory")
+parser.add_argument("--fp16", action="store_true", default=False, help="Use FP16 for inference if available")
+parser.add_argument("--deepspeed", action="store_true", default=False, help="Use DeepSpeed to accelerate if available")
+parser.add_argument("--cuda_kernel", action="store_true", default=False, help="Use CUDA kernel for inference if available")
+parser.add_argument("--gui_seg_tokens", type=int, default=120, help="GUI: Max tokens per generation segment")
 cmd_args = parser.parse_args()
 
+if not os.path.exists(cmd_args.model_dir):
+    print(f"Model directory {cmd_args.model_dir} does not exist. Please download the model first.")
+    sys.exit(1)
+
+for file in [
+    "bpe.model",
+    "gpt.pth",
+    "config.yaml",
+    "s2mel.pth",
+    "wav2vec2bert_stats.pt"
+]:
+    file_path = os.path.join(cmd_args.model_dir, file)
+    if not os.path.exists(file_path):
+        print(f"Required file {file_path} does not exist. Please download it.")
+        sys.exit(1)
 
 import gradio as gr
-from indextts import infer
 from indextts.infer_v2 import IndexTTS2
 from tools.i18n.i18n import I18nAuto
-from modelscope.hub import api
 
 i18n = I18nAuto(language="Auto")
 MODE = 'local'
 tts = IndexTTS2(model_dir=cmd_args.model_dir,
                 cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),
+                use_fp16=cmd_args.fp16,
+                use_deepspeed=cmd_args.deepspeed,
+                use_cuda_kernel=cmd_args.cuda_kernel,
+                )
 # 支持的语言列表
 LANGUAGES = {
     "中文": "zh_CN",
@@ -51,6 +69,9 @@ EMO_CHOICES = [i18n("与音色参考音频相同"),
                i18n("使用情感参考音频"),
                i18n("使用情感向量控制"),
                i18n("使用情感描述文本控制")]
+EMO_CHOICES_BASE = EMO_CHOICES[:3]  # 基础选项
+EMO_CHOICES_EXPERIMENTAL = EMO_CHOICES  # 全部选项(包括文本描述)
+
 os.makedirs("outputs/tasks",exist_ok=True)
 os.makedirs("prompts",exist_ok=True)
@@ -79,15 +100,23 @@ with open("examples/cases.jsonl", "r", encoding="utf-8") as f:
                 example.get("emo_vec_5",0),
                 example.get("emo_vec_6",0),
                 example.get("emo_vec_7",0),
-                example.get("emo_vec_8",0)
+                example.get("emo_vec_8",0),
+                example.get("emo_text") is not None]
                 )
 
+def normalize_emo_vec(emo_vec):
+    # emotion factors for better user experience
+    k_vec = [0.75,0.70,0.80,0.80,0.75,0.75,0.55,0.45]
+    tmp = np.array(k_vec) * np.array(emo_vec)
+    if np.sum(tmp) > 0.8:
+        tmp = tmp * 0.8/ np.sum(tmp)
+    return tmp.tolist()
+
 def gen_single(emo_control_method,prompt, text,
                emo_ref_path, emo_weight,
                vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8,
                emo_text,emo_random,
+               max_text_tokens_per_segment=120,
               *args, progress=gr.Progress()):
     output_path = None
     if not output_path:
@@ -110,28 +139,31 @@ def gen_single(emo_control_method,prompt, text,
     }
     if type(emo_control_method) is not int:
         emo_control_method = emo_control_method.value
-    if emo_control_method == 0:
-        emo_ref_path = None
-        emo_weight = emo_weight
+    if emo_control_method == 0:  # emotion from speaker
+        emo_ref_path = None  # remove external reference audio
+    if emo_control_method == 1:  # emotion from reference audio
+        # normalize emo_alpha for better user experience
+        emo_weight = emo_weight * 0.8
+        pass
+    if emo_control_method == 2:  # emotion from custom vectors
         vec = [vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8]
-        if vec_sum > 1.5:
-            gr.Warning(i18n("情感向量之和不能超过1.5,请调整后重试。"))
-            return
+        vec = normalize_emo_vec(vec)
     else:
+        # don't use the emotion vector inputs for the other modes
        vec = None
 
+    if emo_text == "":
+        # erase empty emotion descriptions; `infer()` will then automatically use the main prompt
+        emo_text = None
+
+    print(f"Emo control mode:{emo_control_method},weight:{emo_weight},vec:{vec}")
     output = tts.infer(spk_audio_prompt=prompt, text=text,
                        output_path=output_path,
                        emo_audio_prompt=emo_ref_path, emo_alpha=emo_weight,
                        emo_vector=vec,
                        use_emo_text=(emo_control_method==3), emo_text=emo_text,use_random=emo_random,
                        verbose=cmd_args.verbose,
+                       max_text_tokens_per_segment=int(max_text_tokens_per_segment),
                        **kwargs)
     return gr.update(value=output,visible=True)
@@ -147,6 +179,7 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
     <a href='https://arxiv.org/abs/2506.21619'><img src='https://img.shields.io/badge/ArXiv-2506.21619-red'></a>
     </p>
     ''')
+
     with gr.Tab(i18n("音频生成")):
         with gr.Row():
             os.makedirs("prompts",exist_ok=True)
@@ -160,49 +193,54 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
             input_text_single = gr.TextArea(label=i18n("文本"),key="input_text_single", placeholder=i18n("请输入目标文本"), info=f"{i18n('当前模型版本')}{tts.model_version or '1.0'}")
             gen_button = gr.Button(i18n("生成语音"), key="gen_button",interactive=True)
             output_audio = gr.Audio(label=i18n("生成结果"), visible=True,key="output_audio")
+            experimental_checkbox = gr.Checkbox(label=i18n("显示实验功能"),value=False)
         with gr.Accordion(i18n("功能设置")):
             # 情感控制选项部分
             with gr.Row():
                 emo_control_method = gr.Radio(
+                    choices=EMO_CHOICES_BASE,
                     type="index",
+                    value=EMO_CHOICES_BASE[0],label=i18n("情感控制方式"))
             # 情感参考音频部分
             with gr.Group(visible=False) as emotion_reference_group:
                 with gr.Row():
                     emo_upload = gr.Audio(label=i18n("上传情感参考音频"), type="filepath")
 
-            with gr.Row():
-                emo_weight = gr.Slider(label=i18n("情感权重"), minimum=0.0, maximum=1.6, value=0.8, step=0.01)
-
             # 情感随机采样
+            with gr.Row(visible=False) as emotion_randomize_group:
+                emo_random = gr.Checkbox(label=i18n("情感随机采样"), value=False)
 
             # 情感向量控制部分
             with gr.Group(visible=False) as emotion_vector_group:
                 with gr.Row():
                     with gr.Column():
+                        vec1 = gr.Slider(label=i18n("喜"), minimum=0.0, maximum=1.0, value=0.0, step=0.05)
+                        vec2 = gr.Slider(label=i18n("怒"), minimum=0.0, maximum=1.0, value=0.0, step=0.05)
+                        vec3 = gr.Slider(label=i18n("哀"), minimum=0.0, maximum=1.0, value=0.0, step=0.05)
+                        vec4 = gr.Slider(label=i18n("惧"), minimum=0.0, maximum=1.0, value=0.0, step=0.05)
                     with gr.Column():
+                        vec5 = gr.Slider(label=i18n("厌恶"), minimum=0.0, maximum=1.0, value=0.0, step=0.05)
+                        vec6 = gr.Slider(label=i18n("低落"), minimum=0.0, maximum=1.0, value=0.0, step=0.05)
+                        vec7 = gr.Slider(label=i18n("惊喜"), minimum=0.0, maximum=1.0, value=0.0, step=0.05)
+                        vec8 = gr.Slider(label=i18n("平静"), minimum=0.0, maximum=1.0, value=0.0, step=0.05)
 
             with gr.Group(visible=False) as emo_text_group:
                 with gr.Row():
+                    emo_text = gr.Textbox(label=i18n("情感描述文本"),
+                                          placeholder=i18n("请输入情绪描述(或留空以自动使用目标文本作为情绪描述)"),
+                                          value="",
+                                          info=i18n("例如:委屈巴巴、危险在悄悄逼近"))
+
+            with gr.Row(visible=False) as emo_weight_group:
+                emo_weight = gr.Slider(label=i18n("情感权重"), minimum=0.0, maximum=1.0, value=0.8, step=0.01)
 
-        with gr.Accordion(i18n("高级生成参数设置"), open=False):
+        with gr.Accordion(i18n("高级生成参数设置"), open=False,visible=False) as advanced_settings_group:
             with gr.Row():
                 with gr.Column(scale=1):
-                    gr.Markdown(f"**{i18n('GPT2 采样设置')}** _{i18n('参数会影响音频多样性和生成速度详见')}[Generation strategies](https://huggingface.co/docs/transformers/main/en/generation_strategies)_")
+                    gr.Markdown(f"**{i18n('GPT2 采样设置')}** _{i18n('参数会影响音频多样性和生成速度详见')} [Generation strategies](https://huggingface.co/docs/transformers/main/en/generation_strategies)._")
                     with gr.Row():
-                        do_sample = gr.Checkbox(label="do_sample", value=True, info="是否进行采样")
+                        do_sample = gr.Checkbox(label="do_sample", value=True, info=i18n("是否进行采样"))
                         temperature = gr.Slider(label="temperature", minimum=0.1, maximum=2.0, value=0.8, step=0.1)
                     with gr.Row():
                        top_p = gr.Slider(label="top_p", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
@@ -211,21 +249,22 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
                    with gr.Row():
                        repetition_penalty = gr.Number(label="repetition_penalty", precision=None, value=10.0, minimum=0.1, maximum=20.0, step=0.1)
                        length_penalty = gr.Number(label="length_penalty", precision=None, value=0.0, minimum=-2.0, maximum=2.0, step=0.1)
-                    max_mel_tokens = gr.Slider(label="max_mel_tokens", value=1500, minimum=50, maximum=tts.cfg.gpt.max_mel_tokens, step=10, info="生成Token最大数量,过小导致音频被截断", key="max_mel_tokens")
+                    max_mel_tokens = gr.Slider(label="max_mel_tokens", value=1500, minimum=50, maximum=tts.cfg.gpt.max_mel_tokens, step=10, info=i18n("生成Token最大数量,过小导致音频被截断"), key="max_mel_tokens")
                    # with gr.Row():
                    #     typical_sampling = gr.Checkbox(label="typical_sampling", value=False, info="不建议使用")
                    #     typical_mass = gr.Slider(label="typical_mass", value=0.9, minimum=0.0, maximum=1.0, step=0.1)
                with gr.Column(scale=2):
                    gr.Markdown(f'**{i18n("分句设置")}** _{i18n("参数会影响音频质量和生成速度")}_')
                    with gr.Row():
+                        initial_value = max(20, min(tts.cfg.gpt.max_text_tokens, cmd_args.gui_seg_tokens))
+                        max_text_tokens_per_segment = gr.Slider(
+                            label=i18n("分句最大Token数"), value=initial_value, minimum=20, maximum=tts.cfg.gpt.max_text_tokens, step=2, key="max_text_tokens_per_segment",
                            info=i18n("建议80~200之间,值越大,分句越长;值越小,分句越碎;过小过大都可能导致音频质量不高"),
                        )
+                    with gr.Accordion(i18n("预览分句结果"), open=True) as segments_settings:
+                        segments_preview = gr.Dataframe(
                            headers=[i18n("序号"), i18n("分句内容"), i18n("Token数")],
+                            key="segments_preview",
                            wrap=True,
                        )
            advanced_params = [
@@ -234,8 +273,20 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
                # typical_sampling, typical_mass,
            ]
 
-            if len(example_cases) >
-                gr.Examples(
+            if len(example_cases) > 2:
+                example_table = gr.Examples(
+                    examples=example_cases[:-2],
+                    examples_per_page=20,
+                    inputs=[prompt_audio,
+                            emo_control_method,
+                            input_text_single,
+                            emo_upload,
+                            emo_weight,
+                            emo_text,
+                            vec1,vec2,vec3,vec4,vec5,vec6,vec7,vec8,experimental_checkbox]
+                )
+            elif len(example_cases) > 0:
+                example_table = gr.Examples(
                    examples=example_cases,
                    examples_per_page=20,
                    inputs=[prompt_audio,
@@ -244,71 +295,93 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
                            emo_upload,
                            emo_weight,
                            emo_text,
-                            vec1,vec2,vec3,vec4,vec5,vec6,vec7,vec8]
+                            vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8, experimental_checkbox]
                )
 
+            def on_input_text_change(text, max_text_tokens_per_segment):
                if text and len(text) > 0:
                    text_tokens_list = tts.tokenizer.tokenize(text)
 
+                    segments = tts.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment=int(max_text_tokens_per_segment))
                    data = []
+                    for i, s in enumerate(segments):
+                        segment_str = ''.join(s)
                        tokens_count = len(s)
+                        data.append([i, segment_str, tokens_count])
                    return {
+                        segments_preview: gr.update(value=data, visible=True, type="array"),
                    }
                else:
                    df = pd.DataFrame([], columns=[i18n("序号"), i18n("分句内容"), i18n("Token数")])
                    return {
+                        segments_preview: gr.update(value=df),
                    }
+
            def on_method_select(emo_control_method):
-                if emo_control_method == 1:
+                if emo_control_method == 1:  # emotion reference audio
                    return (gr.update(visible=True),
                            gr.update(visible=False),
                            gr.update(visible=False),
-                            gr.update(visible=False)
+                            gr.update(visible=False),
+                            gr.update(visible=True)
                            )
-                elif emo_control_method == 2:
+                elif emo_control_method == 2:  # emotion vectors
                    return (gr.update(visible=False),
                            gr.update(visible=True),
                            gr.update(visible=True),
+                            gr.update(visible=False),
                            gr.update(visible=False)
                            )
-                elif emo_control_method == 3:
+                elif emo_control_method == 3:  # emotion text description
                    return (gr.update(visible=False),
                            gr.update(visible=True),
                            gr.update(visible=False),
+                            gr.update(visible=True),
                            gr.update(visible=True)
                            )
-                else:
+                else:  # 0: same as speaker voice
                    return (gr.update(visible=False),
+                            gr.update(visible=False),
                            gr.update(visible=False),
                            gr.update(visible=False),
                            gr.update(visible=False)
                            )
 
+            def on_experimental_change(is_exp):
+                # 切换情感控制选项
+                # 第三个返回值实际没有起作用
+                if is_exp:
+                    return gr.update(choices=EMO_CHOICES_EXPERIMENTAL, value=EMO_CHOICES_EXPERIMENTAL[0]), gr.update(visible=True),gr.update(value=example_cases)
+                else:
+                    return gr.update(choices=EMO_CHOICES_BASE, value=EMO_CHOICES_BASE[0]), gr.update(visible=False),gr.update(value=example_cases[:-2])
+
            emo_control_method.select(on_method_select,
                inputs=[emo_control_method],
                outputs=[emotion_reference_group,
+                        emotion_randomize_group,
                         emotion_vector_group,
+                        emo_text_group,
+                        emo_weight_group]
                )
 
            input_text_single.change(
                on_input_text_change,
+                inputs=[input_text_single, max_text_tokens_per_segment],
+                outputs=[segments_preview]
            )
+
+            experimental_checkbox.change(
+                on_experimental_change,
+                inputs=[experimental_checkbox],
+                outputs=[emo_control_method, advanced_settings_group,example_table.dataset]  # 高级参数Accordion
+            )
+
+            max_text_tokens_per_segment.change(
                on_input_text_change,
+                inputs=[input_text_single, max_text_tokens_per_segment],
+                outputs=[segments_preview]
            )
+
            prompt_audio.upload(update_prompt_audio,
                inputs=[],
                outputs=[gen_button])
@@ -317,7 +390,7 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
        inputs=[emo_control_method,prompt_audio, input_text_single, emo_upload, emo_weight,
                vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8,
                emo_text,emo_random,
+                max_text_tokens_per_segment,
                *advanced_params,
                ],
        outputs=[output_audio])
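A worked example of the normalize_emo_vec() scaling introduced above (a standalone snippet mirroring that function): each emotion slider is damped by a fixed per-emotion factor, and if the damped values still sum to more than 0.8 the whole vector is rescaled to that cap.

    import numpy as np

    # same per-emotion factors as in webui.py:
    # [happy, angry, sad, afraid, disgusted, melancholic, surprised, calm]
    k_vec = [0.75, 0.70, 0.80, 0.80, 0.75, 0.75, 0.55, 0.45]

    def normalize_emo_vec(emo_vec):
        tmp = np.array(k_vec) * np.array(emo_vec)
        if np.sum(tmp) > 0.8:
            tmp = tmp * 0.8 / np.sum(tmp)
        return tmp.tolist()

    # "happy" and "surprised" both at 1.0: damped sum is 0.75 + 0.55 = 1.30 > 0.8,
    # so the vector is rescaled to roughly [0.4615, 0, 0, 0, 0, 0, 0.3385, 0]
    print(normalize_emo_vec([1.0, 0, 0, 0, 0, 0, 1.0, 0]))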