kemuriririn committed
Commit c4fe16f · 1 Parent(s): 04be12f

sync from github

.gitattributes CHANGED
@@ -47,3 +47,4 @@ examples/emo_hate.wav filter=lfs diff=lfs merge=lfs -text
 examples/voice_01.wav filter=lfs diff=lfs merge=lfs -text
 examples/voice_03.wav filter=lfs diff=lfs merge=lfs -text
 examples/voice_04.wav filter=lfs diff=lfs merge=lfs -text
+indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -13,4 +13,5 @@ build/
 *.py[cod]
 *.egg-info/
 .venv
-checkpoints/*
+checkpoints/*
+__MACOSX
examples/cases.jsonl CHANGED
@@ -1,12 +1,12 @@
-{"prompt_audio":"voice_01.wav","text":"Translate for mewhat is a surprise","emo_mode":0}
+{"prompt_audio":"voice_01.wav","text":"Translate for me, what is a surprise!","emo_mode":0}
 {"prompt_audio":"voice_02.wav","text":"The palace is strict, no false rumors, Lady Qi!","emo_mode":0}
 {"prompt_audio":"voice_03.wav","text":"这个呀,就是我们精心制作准备的纪念品,大家可以看到这个色泽和这个材质啊,哎呀多么的光彩照人。","emo_mode":0}
 {"prompt_audio":"voice_04.wav","text":"你就需要我这种专业人士的帮助,就像手无缚鸡之力的人进入雪山狩猎,一定需要最老练的猎人指导。","emo_mode":0}
 {"prompt_audio":"voice_05.wav","text":"在真正的日本剑道中,格斗过程极其短暂,常常短至半秒,最长也不超过两秒,利剑相击的转瞬间,已有一方倒在血泊中。但在这电光石火的对决之前,双方都要以一个石雕般凝固的姿势站定,长时间的逼视对方,这一过程可能长达十分钟!","emo_mode":0}
 {"prompt_audio":"voice_06.wav","text":"今天呢,咱们开一部新书,叫《赛博朋克二零七七》。这词儿我听着都新鲜。这赛博朋克啊,简单理解就是“高科技,低生活”。这一听,我就明白了,于老师就爱用那高科技的东西,手机都得拿脚纹开,大冬天为了解锁脱得一丝不挂,冻得跟王八蛋似的。","emo_mode":0}
-{"prompt_audio":"voice_07.wav","emo_audio":"emo_sad.wav","emo_weight": 0.9, "emo_mode":1,"text":"酒楼丧尽天良,开始借机竞拍房间,哎,一群蠢货。"}
-{"prompt_audio":"voice_08.wav","emo_audio":"emo_hate.wav","emo_weight": 0.8, "emo_mode":1,"text":"你看看你,对我还有没有一点父子之间的信任了。"}
-{"prompt_audio":"voice_09.wav","emo_vec_3":0.55,"emo_mode":2,"text":"对不起嘛!我的记性真的不太好,但是和你在一起的事情,我都会努力记住的~"}
-{"prompt_audio":"voice_10.wav","emo_vec_7":0.45,"emo_mode":2,"text":"哇塞!这个爆率也太高了!欧皇附体了!"}
+{"prompt_audio":"voice_07.wav","emo_audio":"emo_sad.wav","emo_weight": 1.0, "emo_mode":1,"text":"酒楼丧尽天良,开始借机竞拍房间,哎,一群蠢货。"}
+{"prompt_audio":"voice_08.wav","emo_audio":"emo_hate.wav","emo_weight": 1.0, "emo_mode":1,"text":"你看看你,对我还有没有一点父子之间的信任了。"}
+{"prompt_audio":"voice_09.wav","emo_vec_3":0.8,"emo_mode":2,"text":"对不起嘛!我的记性真的不太好,但是和你在一起的事情,我都会努力记住的~"}
+{"prompt_audio":"voice_10.wav","emo_vec_7":1.0,"emo_mode":2,"text":"哇塞!这个爆率也太高了!欧皇附体了!"}
 {"prompt_audio":"voice_11.wav","emo_mode":3,"emo_text":"极度悲伤","text":"这些年的时光终究是错付了... "}
 {"prompt_audio":"voice_12.wav","emo_mode":3,"emo_text":"You scared me to death! What are you, a ghost?","text":"快躲起来!是他要来了!他要来抓我们了!"}
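Note on the updated example cases: judging by the fields each entry uses, emo_mode selects how emotion is controlled (0 = speaker audio only, 1 = emotion reference audio plus emo_weight, 2 = a manual emo_vec_* component, 3 = free-form emo_text). A minimal sketch of how these entries could be read and dispatched; the loader below is hypothetical, and only the key names come from the file above:

    import json

    def load_cases(path="examples/cases.jsonl"):
        # One JSON object per line, exactly as in the file shown above.
        with open(path, encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    for case in load_cases():
        mode = case.get("emo_mode", 0)
        print(case["prompt_audio"], mode, case["text"][:30])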
indextts/cli.py CHANGED
@@ -12,9 +12,9 @@ def main():
     parser.add_argument("-o", "--output_path", type=str, default="gen.wav", help="Path to the output wav file")
     parser.add_argument("-c", "--config", type=str, default="checkpoints/config.yaml", help="Path to the config file. Default is 'checkpoints/config.yaml'")
     parser.add_argument("--model_dir", type=str, default="checkpoints", help="Path to the model directory. Default is 'checkpoints'")
-    parser.add_argument("--fp16", action="store_true", default=True, help="Use FP16 for inference if available")
+    parser.add_argument("--fp16", action="store_true", default=False, help="Use FP16 for inference if available")
     parser.add_argument("-f", "--force", action="store_true", default=False, help="Force to overwrite the output file if it exists")
-    parser.add_argument("-d", "--device", type=str, default=None, help="Device to run the model on (cpu, cuda, mps).")
+    parser.add_argument("-d", "--device", type=str, default=None, help="Device to run the model on (cpu, cuda, mps, xpu).")
     args = parser.parse_args()
     if len(args.text.strip()) == 0:
         print("ERROR: Text is empty.")
@@ -47,15 +47,18 @@ def main():
     if args.device is None:
         if torch.cuda.is_available():
             args.device = "cuda:0"
-        elif torch.mps.is_available():
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            args.device = "xpu"
+        elif hasattr(torch, "mps") and torch.mps.is_available():
             args.device = "mps"
         else:
             args.device = "cpu"
             args.fp16 = False  # Disable FP16 on CPU
             print("WARNING: Running on CPU may be slow.")
 
+    # TODO: Add CLI support for IndexTTS2.
     from indextts.infer import IndexTTS
-    tts = IndexTTS(cfg_path=args.config, model_dir=args.model_dir, is_fp16=args.fp16, device=args.device)
+    tts = IndexTTS(cfg_path=args.config, model_dir=args.model_dir, use_fp16=args.fp16, device=args.device)
     tts.infer(audio_prompt=args.voice, text=args.text.strip(), output_path=output_path)
 
 if __name__ == "__main__":
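The CLI changes above flip --fp16 to opt-in, rename the constructor flag to use_fp16, and add Intel XPU to the device auto-detection. A rough programmatic equivalent of what the updated script now does, as a sketch (file paths are placeholders; the fallback order mirrors the diff):

    import torch
    from indextts.infer import IndexTTS

    # Same device fallback order as the CLI above: cuda -> xpu -> mps -> cpu.
    if torch.cuda.is_available():
        device = "cuda:0"
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        device = "xpu"
    elif hasattr(torch, "mps") and torch.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints",
                   use_fp16=False, device=device)
    tts.infer(audio_prompt="examples/voice_01.wav",
              text="Translate for me, what is a surprise!",
              output_path="gen.wav")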
indextts/infer.py CHANGED
@@ -26,38 +26,42 @@ from indextts.utils.front import TextNormalizer, TextTokenizer
26
 
27
  class IndexTTS:
28
  def __init__(
29
- self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True, device=None,
30
  use_cuda_kernel=None,
31
  ):
32
  """
33
  Args:
34
  cfg_path (str): path to the config file.
35
  model_dir (str): path to the model directory.
36
- is_fp16 (bool): whether to use fp16.
37
  device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
38
  use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
39
  """
40
  if device is not None:
41
  self.device = device
42
- self.is_fp16 = False if device == "cpu" else is_fp16
43
  self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
44
  elif torch.cuda.is_available():
45
  self.device = "cuda:0"
46
- self.is_fp16 = is_fp16
47
  self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
48
  elif hasattr(torch, "mps") and torch.backends.mps.is_available():
49
  self.device = "mps"
50
- self.is_fp16 = False # Use float16 on MPS is overhead than float32
51
  self.use_cuda_kernel = False
52
  else:
53
  self.device = "cpu"
54
- self.is_fp16 = False
55
  self.use_cuda_kernel = False
56
  print(">> Be patient, it may take a while to run in CPU mode.")
57
 
58
  self.cfg = OmegaConf.load(cfg_path)
59
  self.model_dir = model_dir
60
- self.dtype = torch.float16 if self.is_fp16 else None
61
  self.stop_mel_token = self.cfg.gpt.stop_mel_token
62
 
63
  # Comment-off to load the VQ-VAE model for debugging tokenizer
@@ -68,7 +72,7 @@ class IndexTTS:
68
  # self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint)
69
  # load_checkpoint(self.dvae, self.dvae_path)
70
  # self.dvae = self.dvae.to(self.device)
71
- # if self.is_fp16:
72
  # self.dvae.eval().half()
73
  # else:
74
  # self.dvae.eval()
@@ -77,12 +81,12 @@ class IndexTTS:
77
  self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
78
  load_checkpoint(self.gpt, self.gpt_path)
79
  self.gpt = self.gpt.to(self.device)
80
- if self.is_fp16:
81
  self.gpt.eval().half()
82
  else:
83
  self.gpt.eval()
84
  print(">> GPT weights restored from:", self.gpt_path)
85
- if self.is_fp16:
86
  try:
87
  import deepspeed
88
 
@@ -184,17 +188,17 @@ class IndexTTS:
184
  code_lens = torch.tensor(code_lens, dtype=torch.long, device=device)
185
  return codes, code_lens
186
 
187
- def bucket_sentences(self, sentences, bucket_max_size=4) -> List[List[Dict]]:
188
  """
189
- Sentence data bucketing.
190
- if ``bucket_max_size=1``, return all sentences in one bucket.
191
  """
192
  outputs: List[Dict] = []
193
- for idx, sent in enumerate(sentences):
194
  outputs.append({"idx": idx, "sent": sent, "len": len(sent)})
195
 
196
  if len(outputs) > bucket_max_size:
197
- # split sentences into buckets by sentence length
198
  buckets: List[List[Dict]] = []
199
  factor = 1.5
200
  last_bucket = None
@@ -203,7 +207,7 @@ class IndexTTS:
203
  for sent in sorted(outputs, key=lambda x: x["len"]):
204
  current_sent_len = sent["len"]
205
  if current_sent_len == 0:
206
- print(">> skip empty sentence")
207
  continue
208
  if last_bucket is None \
209
  or current_sent_len >= int(last_bucket_sent_len_median * factor) \
@@ -213,7 +217,7 @@ class IndexTTS:
213
  last_bucket = buckets[-1]
214
  last_bucket_sent_len_median = current_sent_len
215
  else:
216
- # current bucket can hold more sentences
217
  last_bucket.append(sent) # sorted
218
  mid = len(last_bucket) // 2
219
  last_bucket_sent_len_median = last_bucket[mid]["len"]
@@ -276,20 +280,20 @@ class IndexTTS:
276
  self.gr_progress(value, desc=desc)
277
 
278
  # 快速推理:对于“多句长文本”,可实现至少 2~10 倍以上的速度提升~ (First modified by sunnyboxs 2025-04-16)
279
- def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=100,
280
- sentences_bucket_max_size=4, **generation_kwargs):
281
  """
282
  Args:
283
- ``max_text_tokens_per_sentence``: 分句的最大token数,默认``100``,可以根据GPU硬件情况调整
284
  - 越小,batch 越多,推理速度越*快*,占用内存更多,可能影响质量
285
  - 越大,batch 越少,推理速度越*慢*,占用内存和质量更接近于非快速推理
286
- ``sentences_bucket_max_size``: 分句分桶的最大容量,默认``4``,可以根据GPU内存调整
287
  - 越大,bucket数量越少,batch越多,推理速度越*快*,占用内存更多,可能影响质量
288
  - 越小,bucket数量越多,batch越少,推理速度越*慢*,占用内存和质量更接近于非快速推理
289
  """
290
- print(">> start fast inference...")
291
 
292
- self._set_gr_progress(0, "start fast inference...")
293
  if verbose:
294
  print(f"origin text:{text}")
295
  start_time = time.perf_counter()
@@ -301,6 +305,15 @@ class IndexTTS:
301
  if audio.shape[0] > 1:
302
  audio = audio[0].unsqueeze(0)
303
  audio = torchaudio.transforms.Resample(sr, 24000)(audio)
304
  cond_mel = MelSpectrogramFeatures()(audio).to(self.device)
305
  cond_mel_frame = cond_mel.shape[-1]
306
  if verbose:
@@ -319,13 +332,13 @@ class IndexTTS:
319
  # text_tokens
320
  text_tokens_list = self.tokenizer.tokenize(text)
321
 
322
- sentences = self.tokenizer.split_sentences(text_tokens_list,
323
- max_tokens_per_sentence=max_text_tokens_per_sentence)
324
  if verbose:
325
  print(">> text token count:", len(text_tokens_list))
326
- print(" splited sentences count:", len(sentences))
327
- print(" max_text_tokens_per_sentence:", max_text_tokens_per_sentence)
328
- print(*sentences, sep="\n")
329
  do_sample = generation_kwargs.pop("do_sample", True)
330
  top_p = generation_kwargs.pop("top_p", 0.8)
331
  top_k = generation_kwargs.pop("top_k", 30)
@@ -346,17 +359,17 @@ class IndexTTS:
346
  # text processing
347
  all_text_tokens: List[List[torch.Tensor]] = []
348
  self._set_gr_progress(0.1, "text processing...")
349
- bucket_max_size = sentences_bucket_max_size if self.device != "cpu" else 1
350
- all_sentences = self.bucket_sentences(sentences, bucket_max_size=bucket_max_size)
351
- bucket_count = len(all_sentences)
352
  if verbose:
353
- print(">> sentences bucket_count:", bucket_count,
354
- "bucket sizes:", [(len(s), [t["idx"] for t in s]) for s in all_sentences],
355
  "bucket_max_size:", bucket_max_size)
356
- for sentences in all_sentences:
357
  temp_tokens: List[torch.Tensor] = []
358
  all_text_tokens.append(temp_tokens)
359
- for item in sentences:
360
  sent = item["sent"]
361
  text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
362
  text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
@@ -365,11 +378,11 @@ class IndexTTS:
365
  print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
366
  # debug tokenizer
367
  text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
368
- print("text_token_syms is same as sentence tokens", text_token_syms == sent)
369
  temp_tokens.append(text_tokens)
370
 
371
  # Sequential processing of bucketing data
372
- all_batch_num = sum(len(s) for s in all_sentences)
373
  all_batch_codes = []
374
  processed_num = 0
375
  for item_tokens in all_text_tokens:
@@ -381,7 +394,7 @@ class IndexTTS:
381
  processed_num += batch_num
382
  # gpt speech
383
  self._set_gr_progress(0.2 + 0.3 * processed_num / all_batch_num,
384
- f"gpt inference speech... {processed_num}/{all_batch_num}")
385
  m_start_time = time.perf_counter()
386
  with torch.no_grad():
387
  with torch.amp.autocast(batch_text_tokens.device.type, enabled=self.dtype is not None,
@@ -403,17 +416,17 @@ class IndexTTS:
403
  gpt_gen_time += time.perf_counter() - m_start_time
404
 
405
  # gpt latent
406
- self._set_gr_progress(0.5, "gpt inference latents...")
407
  all_idxs = []
408
  all_latents = []
409
  has_warned = False
410
- for batch_codes, batch_tokens, batch_sentences in zip(all_batch_codes, all_text_tokens, all_sentences):
411
  for i in range(batch_codes.shape[0]):
412
  codes = batch_codes[i] # [x]
413
  if not has_warned and codes[-1] != self.stop_mel_token:
414
  warnings.warn(
415
  f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
416
- f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.",
417
  category=RuntimeWarning
418
  )
419
  has_warned = True
@@ -427,7 +440,7 @@ class IndexTTS:
427
  print(codes)
428
  print("code_lens:", code_lens)
429
  text_tokens = batch_tokens[i]
430
- all_idxs.append(batch_sentences[i]["idx"])
431
  m_start_time = time.perf_counter()
432
  with torch.no_grad():
433
  with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
@@ -440,7 +453,7 @@ class IndexTTS:
440
  return_latent=True, clip_inputs=False)
441
  gpt_forward_time += time.perf_counter() - m_start_time
442
  all_latents.append(latent)
443
- del all_batch_codes, all_text_tokens, all_sentences
444
  # bigvgan chunk
445
  chunk_size = 2
446
  all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))]
@@ -452,7 +465,7 @@ class IndexTTS:
452
  latent_length = len(all_latents)
453
 
454
  # bigvgan chunk decode
455
- self._set_gr_progress(0.7, "bigvgan decode...")
456
  tqdm_progress = tqdm(total=latent_length, desc="bigvgan")
457
  for items in chunk_latents:
458
  tqdm_progress.update(len(items))
@@ -474,7 +487,7 @@ class IndexTTS:
474
  self.torch_empty_cache()
475
 
476
  # wav audio output
477
- self._set_gr_progress(0.9, "save audio...")
478
  wav = torch.cat(wavs, dim=1)
479
  wav_length = wav.shape[-1] / sampling_rate
480
  print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds")
@@ -503,10 +516,10 @@ class IndexTTS:
503
  return (sampling_rate, wav_data)
504
 
505
  # 原始推理模式
506
- def infer(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=120,
507
  **generation_kwargs):
508
- print(">> start inference...")
509
- self._set_gr_progress(0, "start inference...")
510
  if verbose:
511
  print(f"origin text:{text}")
512
  start_time = time.perf_counter()
@@ -533,12 +546,12 @@ class IndexTTS:
533
  self._set_gr_progress(0.1, "text processing...")
534
  auto_conditioning = cond_mel
535
  text_tokens_list = self.tokenizer.tokenize(text)
536
- sentences = self.tokenizer.split_sentences(text_tokens_list, max_text_tokens_per_sentence)
537
  if verbose:
538
  print("text token count:", len(text_tokens_list))
539
- print("sentences count:", len(sentences))
540
- print("max_text_tokens_per_sentence:", max_text_tokens_per_sentence)
541
- print(*sentences, sep="\n")
542
  do_sample = generation_kwargs.pop("do_sample", True)
543
  top_p = generation_kwargs.pop("top_p", 0.8)
544
  top_k = generation_kwargs.pop("top_k", 30)
@@ -557,7 +570,7 @@ class IndexTTS:
557
  bigvgan_time = 0
558
  progress = 0
559
  has_warned = False
560
- for sent in sentences:
561
  text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
562
  text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
563
  # text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
@@ -568,13 +581,13 @@ class IndexTTS:
568
  print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
569
  # debug tokenizer
570
  text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
571
- print("text_token_syms is same as sentence tokens", text_token_syms == sent)
572
 
573
  # text_len = torch.IntTensor([text_tokens.size(1)], device=text_tokens.device)
574
  # print(text_len)
575
  progress += 1
576
- self._set_gr_progress(0.2 + 0.4 * (progress - 1) / len(sentences),
577
- f"gpt inference latent... {progress}/{len(sentences)}")
578
  m_start_time = time.perf_counter()
579
  with torch.no_grad():
580
  with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
@@ -597,7 +610,7 @@ class IndexTTS:
597
  warnings.warn(
598
  f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
599
  f"Input text tokens: {text_tokens.shape[1]}. "
600
- f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.",
601
  category=RuntimeWarning
602
  )
603
  has_warned = True
@@ -615,8 +628,8 @@ class IndexTTS:
615
  print(codes, type(codes))
616
  print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}")
617
  print(f"code len: {code_lens}")
618
- self._set_gr_progress(0.2 + 0.4 * progress / len(sentences),
619
- f"gpt inference speech... {progress}/{len(sentences)}")
620
  m_start_time = time.perf_counter()
621
  # latent, text_lens_out, code_lens_out = \
622
  with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
@@ -640,7 +653,7 @@ class IndexTTS:
640
  # wavs.append(wav[:, :-512])
641
  wavs.append(wav.cpu()) # to cpu before saving
642
  end_time = time.perf_counter()
643
- self._set_gr_progress(0.9, "save audio...")
644
  wav = torch.cat(wavs, dim=1)
645
  wav_length = wav.shape[-1] / sampling_rate
646
  print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds")
 
26
 
27
  class IndexTTS:
28
  def __init__(
29
+ self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=True, device=None,
30
  use_cuda_kernel=None,
31
  ):
32
  """
33
  Args:
34
  cfg_path (str): path to the config file.
35
  model_dir (str): path to the model directory.
36
+ use_fp16 (bool): whether to use fp16.
37
  device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
38
  use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
39
  """
40
  if device is not None:
41
  self.device = device
42
+ self.use_fp16 = False if device == "cpu" else use_fp16
43
  self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
44
  elif torch.cuda.is_available():
45
  self.device = "cuda:0"
46
+ self.use_fp16 = use_fp16
47
  self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
48
+ elif hasattr(torch, "xpu") and torch.xpu.is_available():
49
+ self.device = "xpu"
50
+ self.use_fp16 = use_fp16
51
+ self.use_cuda_kernel = False
52
  elif hasattr(torch, "mps") and torch.backends.mps.is_available():
53
  self.device = "mps"
54
+ self.use_fp16 = False # Use float16 on MPS is overhead than float32
55
  self.use_cuda_kernel = False
56
  else:
57
  self.device = "cpu"
58
+ self.use_fp16 = False
59
  self.use_cuda_kernel = False
60
  print(">> Be patient, it may take a while to run in CPU mode.")
61
 
62
  self.cfg = OmegaConf.load(cfg_path)
63
  self.model_dir = model_dir
64
+ self.dtype = torch.float16 if self.use_fp16 else None
65
  self.stop_mel_token = self.cfg.gpt.stop_mel_token
66
 
67
  # Comment-off to load the VQ-VAE model for debugging tokenizer
 
72
  # self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint)
73
  # load_checkpoint(self.dvae, self.dvae_path)
74
  # self.dvae = self.dvae.to(self.device)
75
+ # if self.use_fp16:
76
  # self.dvae.eval().half()
77
  # else:
78
  # self.dvae.eval()
 
81
  self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
82
  load_checkpoint(self.gpt, self.gpt_path)
83
  self.gpt = self.gpt.to(self.device)
84
+ if self.use_fp16:
85
  self.gpt.eval().half()
86
  else:
87
  self.gpt.eval()
88
  print(">> GPT weights restored from:", self.gpt_path)
89
+ if self.use_fp16:
90
  try:
91
  import deepspeed
92
 
 
188
  code_lens = torch.tensor(code_lens, dtype=torch.long, device=device)
189
  return codes, code_lens
190
 
191
+ def bucket_segments(self, segments, bucket_max_size=4) -> List[List[Dict]]:
192
  """
193
+ Segment data bucketing.
194
+ if ``bucket_max_size=1``, return all segments in one bucket.
195
  """
196
  outputs: List[Dict] = []
197
+ for idx, sent in enumerate(segments):
198
  outputs.append({"idx": idx, "sent": sent, "len": len(sent)})
199
 
200
  if len(outputs) > bucket_max_size:
201
+ # split segments into buckets by segment length
202
  buckets: List[List[Dict]] = []
203
  factor = 1.5
204
  last_bucket = None
 
207
  for sent in sorted(outputs, key=lambda x: x["len"]):
208
  current_sent_len = sent["len"]
209
  if current_sent_len == 0:
210
+ print(">> skip empty segment")
211
  continue
212
  if last_bucket is None \
213
  or current_sent_len >= int(last_bucket_sent_len_median * factor) \
 
217
  last_bucket = buckets[-1]
218
  last_bucket_sent_len_median = current_sent_len
219
  else:
220
+ # current bucket can hold more segments
221
  last_bucket.append(sent) # sorted
222
  mid = len(last_bucket) // 2
223
  last_bucket_sent_len_median = last_bucket[mid]["len"]
 
280
  self.gr_progress(value, desc=desc)
281
 
282
  # 快速推理:对于“多句长文本”,可实现至少 2~10 倍以上的速度提升~ (First modified by sunnyboxs 2025-04-16)
283
+ def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_segment=100,
284
+ segments_bucket_max_size=4, **generation_kwargs):
285
  """
286
  Args:
287
+ ``max_text_tokens_per_segment``: 分句的最大token数,默认``100``,可以根据GPU硬件情况调整
288
  - 越小,batch 越多,推理速度越*快*,占用内存更多,可能影响质量
289
  - 越大,batch 越少,推理速度越*慢*,占用内存和质量更接近于非快速推理
290
+ ``segments_bucket_max_size``: 分句分桶的最大容量,默认``4``,可以根据GPU内存调整
291
  - 越大,bucket数量越少,batch越多,推理速度越*快*,占用内存更多,可能影响质量
292
  - 越小,bucket数量越多,batch越少,推理速度越*慢*,占用内存和质量更接近于非快速推理
293
  """
294
+ print(">> starting fast inference...")
295
 
296
+ self._set_gr_progress(0, "starting fast inference...")
297
  if verbose:
298
  print(f"origin text:{text}")
299
  start_time = time.perf_counter()
 
305
  if audio.shape[0] > 1:
306
  audio = audio[0].unsqueeze(0)
307
  audio = torchaudio.transforms.Resample(sr, 24000)(audio)
308
+
309
+ max_audio_length_seconds = 50
310
+ max_audio_samples = int(max_audio_length_seconds * 24000)
311
+
312
+ if audio.shape[1] > max_audio_samples:
313
+ if verbose:
314
+ print(f"Audio too long ({audio.shape[1]} samples), truncating to {max_audio_samples} samples")
315
+ audio = audio[:, :max_audio_samples]
316
+
317
  cond_mel = MelSpectrogramFeatures()(audio).to(self.device)
318
  cond_mel_frame = cond_mel.shape[-1]
319
  if verbose:
 
332
  # text_tokens
333
  text_tokens_list = self.tokenizer.tokenize(text)
334
 
335
+ segments = self.tokenizer.split_segments(text_tokens_list,
336
+ max_text_tokens_per_segment=max_text_tokens_per_segment)
337
  if verbose:
338
  print(">> text token count:", len(text_tokens_list))
339
+ print(" segments count:", len(segments))
340
+ print(" max_text_tokens_per_segment:", max_text_tokens_per_segment)
341
+ print(*segments, sep="\n")
342
  do_sample = generation_kwargs.pop("do_sample", True)
343
  top_p = generation_kwargs.pop("top_p", 0.8)
344
  top_k = generation_kwargs.pop("top_k", 30)
 
359
  # text processing
360
  all_text_tokens: List[List[torch.Tensor]] = []
361
  self._set_gr_progress(0.1, "text processing...")
362
+ bucket_max_size = segments_bucket_max_size if self.device != "cpu" else 1
363
+ all_segments = self.bucket_segments(segments, bucket_max_size=bucket_max_size)
364
+ bucket_count = len(all_segments)
365
  if verbose:
366
+ print(">> segments bucket_count:", bucket_count,
367
+ "bucket sizes:", [(len(s), [t["idx"] for t in s]) for s in all_segments],
368
  "bucket_max_size:", bucket_max_size)
369
+ for segments in all_segments:
370
  temp_tokens: List[torch.Tensor] = []
371
  all_text_tokens.append(temp_tokens)
372
+ for item in segments:
373
  sent = item["sent"]
374
  text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
375
  text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
 
378
  print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
379
  # debug tokenizer
380
  text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
381
+ print("text_token_syms is same as segment tokens", text_token_syms == sent)
382
  temp_tokens.append(text_tokens)
383
 
384
  # Sequential processing of bucketing data
385
+ all_batch_num = sum(len(s) for s in all_segments)
386
  all_batch_codes = []
387
  processed_num = 0
388
  for item_tokens in all_text_tokens:
 
394
  processed_num += batch_num
395
  # gpt speech
396
  self._set_gr_progress(0.2 + 0.3 * processed_num / all_batch_num,
397
+ f"gpt speech inference {processed_num}/{all_batch_num}...")
398
  m_start_time = time.perf_counter()
399
  with torch.no_grad():
400
  with torch.amp.autocast(batch_text_tokens.device.type, enabled=self.dtype is not None,
 
416
  gpt_gen_time += time.perf_counter() - m_start_time
417
 
418
  # gpt latent
419
+ self._set_gr_progress(0.5, "gpt latents inference...")
420
  all_idxs = []
421
  all_latents = []
422
  has_warned = False
423
+ for batch_codes, batch_tokens, batch_segments in zip(all_batch_codes, all_text_tokens, all_segments):
424
  for i in range(batch_codes.shape[0]):
425
  codes = batch_codes[i] # [x]
426
  if not has_warned and codes[-1] != self.stop_mel_token:
427
  warnings.warn(
428
  f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
429
+ f"Consider reducing `max_text_tokens_per_segment`({max_text_tokens_per_segment}) or increasing `max_mel_tokens`.",
430
  category=RuntimeWarning
431
  )
432
  has_warned = True
 
440
  print(codes)
441
  print("code_lens:", code_lens)
442
  text_tokens = batch_tokens[i]
443
+ all_idxs.append(batch_segments[i]["idx"])
444
  m_start_time = time.perf_counter()
445
  with torch.no_grad():
446
  with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
 
453
  return_latent=True, clip_inputs=False)
454
  gpt_forward_time += time.perf_counter() - m_start_time
455
  all_latents.append(latent)
456
+ del all_batch_codes, all_text_tokens, all_segments
457
  # bigvgan chunk
458
  chunk_size = 2
459
  all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))]
 
465
  latent_length = len(all_latents)
466
 
467
  # bigvgan chunk decode
468
+ self._set_gr_progress(0.7, "bigvgan decoding...")
469
  tqdm_progress = tqdm(total=latent_length, desc="bigvgan")
470
  for items in chunk_latents:
471
  tqdm_progress.update(len(items))
 
487
  self.torch_empty_cache()
488
 
489
  # wav audio output
490
+ self._set_gr_progress(0.9, "saving audio...")
491
  wav = torch.cat(wavs, dim=1)
492
  wav_length = wav.shape[-1] / sampling_rate
493
  print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds")
 
516
  return (sampling_rate, wav_data)
517
 
518
  # 原始推理模式
519
+ def infer(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_segment=120,
520
  **generation_kwargs):
521
+ print(">> starting inference...")
522
+ self._set_gr_progress(0, "starting inference...")
523
  if verbose:
524
  print(f"origin text:{text}")
525
  start_time = time.perf_counter()
 
546
  self._set_gr_progress(0.1, "text processing...")
547
  auto_conditioning = cond_mel
548
  text_tokens_list = self.tokenizer.tokenize(text)
549
+ segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment)
550
  if verbose:
551
  print("text token count:", len(text_tokens_list))
552
+ print("segments count:", len(segments))
553
+ print("max_text_tokens_per_segment:", max_text_tokens_per_segment)
554
+ print(*segments, sep="\n")
555
  do_sample = generation_kwargs.pop("do_sample", True)
556
  top_p = generation_kwargs.pop("top_p", 0.8)
557
  top_k = generation_kwargs.pop("top_k", 30)
 
570
  bigvgan_time = 0
571
  progress = 0
572
  has_warned = False
573
+ for sent in segments:
574
  text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
575
  text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
576
  # text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
 
581
  print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
582
  # debug tokenizer
583
  text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
584
+ print("text_token_syms is same as segment tokens", text_token_syms == sent)
585
 
586
  # text_len = torch.IntTensor([text_tokens.size(1)], device=text_tokens.device)
587
  # print(text_len)
588
  progress += 1
589
+ self._set_gr_progress(0.2 + 0.4 * (progress - 1) / len(segments),
590
+ f"gpt latents inference {progress}/{len(segments)}...")
591
  m_start_time = time.perf_counter()
592
  with torch.no_grad():
593
  with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
 
610
  warnings.warn(
611
  f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
612
  f"Input text tokens: {text_tokens.shape[1]}. "
613
+ f"Consider reducing `max_text_tokens_per_segment`({max_text_tokens_per_segment}) or increasing `max_mel_tokens`.",
614
  category=RuntimeWarning
615
  )
616
  has_warned = True
 
628
  print(codes, type(codes))
629
  print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}")
630
  print(f"code len: {code_lens}")
631
+ self._set_gr_progress(0.2 + 0.4 * progress / len(segments),
632
+ f"gpt speech inference {progress}/{len(segments)}...")
633
  m_start_time = time.perf_counter()
634
  # latent, text_lens_out, code_lens_out = \
635
  with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
 
653
  # wavs.append(wav[:, :-512])
654
  wavs.append(wav.cpu()) # to cpu before saving
655
  end_time = time.perf_counter()
656
+ self._set_gr_progress(0.9, "saving audio...")
657
  wav = torch.cat(wavs, dim=1)
658
  wav_length = wav.shape[-1] / sampling_rate
659
  print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds")
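Most of the infer.py changes above are a mechanical sentence→segment rename plus reference-audio truncation (prompts longer than 50 s at 24 kHz are now clipped before the conditioning mel is computed). The bucketing strategy itself is unchanged; a simplified, self-contained sketch of that idea, not the exact method:

    from typing import Dict, List

    def bucket_segments(segments: List[str], bucket_max_size: int = 4, factor: float = 1.5) -> List[List[Dict]]:
        # With few segments everything stays in one bucket; otherwise sort by length
        # and open a new bucket when a segment is much longer than the current
        # bucket's median length or the bucket is already full.
        items = [{"idx": i, "sent": s, "len": len(s)} for i, s in enumerate(segments)]
        if len(items) <= bucket_max_size:
            return [items]
        buckets: List[List[Dict]] = []
        median_len = 0
        for item in sorted(items, key=lambda x: x["len"]):
            if item["len"] == 0:
                continue  # skip empty segment
            if not buckets or item["len"] >= int(median_len * factor) or len(buckets[-1]) >= bucket_max_size:
                buckets.append([item])
                median_len = item["len"]
            else:
                buckets[-1].append(item)
                median_len = buckets[-1][len(buckets[-1]) // 2]["len"]
        return buckets

    print(bucket_segments(["short", "also short", "a noticeably longer segment of text", "tiny"], bucket_max_size=2))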
indextts/infer_v2.py CHANGED
@@ -1,6 +1,9 @@
1
  import os
2
  from subprocess import CalledProcessError
3
 
 
 
 
4
  import time
5
  import librosa
6
  import torch
@@ -34,38 +37,43 @@ import torch.nn.functional as F
34
 
35
  class IndexTTS2:
36
  def __init__(
37
- self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, device=None,
38
  use_cuda_kernel=None,use_deepspeed=False
39
  ):
40
  """
41
  Args:
42
  cfg_path (str): path to the config file.
43
  model_dir (str): path to the model directory.
44
- is_fp16 (bool): whether to use fp16.
45
  device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
46
  use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
 
47
  """
48
  if device is not None:
49
  self.device = device
50
- self.is_fp16 = False if device == "cpu" else is_fp16
51
  self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
52
  elif torch.cuda.is_available():
53
  self.device = "cuda:0"
54
- self.is_fp16 = is_fp16
55
  self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
56
  elif hasattr(torch, "mps") and torch.backends.mps.is_available():
57
  self.device = "mps"
58
- self.is_fp16 = False # Use float16 on MPS is overhead than float32
59
  self.use_cuda_kernel = False
60
  else:
61
  self.device = "cpu"
62
- self.is_fp16 = False
63
  self.use_cuda_kernel = False
64
  print(">> Be patient, it may take a while to run in CPU mode.")
65
 
66
  self.cfg = OmegaConf.load(cfg_path)
67
  self.model_dir = model_dir
68
- self.dtype = torch.float16 if self.is_fp16 else None
69
  self.stop_mel_token = self.cfg.gpt.stop_mel_token
70
 
71
  self.qwen_emo = QwenEmotion(os.path.join(self.model_dir, self.cfg.qwen_emo_path))
@@ -74,32 +82,30 @@ class IndexTTS2:
74
  self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
75
  load_checkpoint(self.gpt, self.gpt_path)
76
  self.gpt = self.gpt.to(self.device)
77
- if self.is_fp16:
78
  self.gpt.eval().half()
79
  else:
80
  self.gpt.eval()
81
  print(">> GPT weights restored from:", self.gpt_path)
82
- if self.is_fp16:
 
83
  try:
84
  import deepspeed
85
-
86
  except (ImportError, OSError, CalledProcessError) as e:
87
  use_deepspeed = False
88
- print(f">> DeepSpeed加载失败,回退到标准推理: {e}")
89
 
90
- self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=True)
91
- else:
92
- self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=False)
93
 
94
  if self.use_cuda_kernel:
95
  # preload the CUDA kernel for BigVGAN
96
  try:
97
- from indextts.BigVGAN.alias_free_activation.cuda import load
98
 
99
- anti_alias_activation_cuda = load.load()
100
- print(">> Preload custom CUDA kernel for BigVGAN", anti_alias_activation_cuda)
101
- except:
102
  print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.")
 
103
  self.use_cuda_kernel = False
104
 
105
  self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
@@ -143,7 +149,7 @@ class IndexTTS2:
143
  print(">> campplus_model weights restored from:", campplus_ckpt_path)
144
 
145
  bigvgan_name = self.cfg.vocoder.name
146
- self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=False)
147
  self.bigvgan = self.bigvgan.to(self.device)
148
  self.bigvgan.remove_weight_norm()
149
  self.bigvgan.eval()
@@ -261,7 +267,7 @@ class IndexTTS2:
261
 
262
  def insert_interval_silence(self, wavs, sampling_rate=22050, interval_silence=200):
263
  """
264
- Insert silences between sentences.
265
  wavs: List[torch.tensor]
266
  """
267
 
@@ -286,47 +292,69 @@ class IndexTTS2:
286
  if self.gr_progress is not None:
287
  self.gr_progress(value, desc=desc)
288
 
289
  # 原始推理模式
290
  def infer(self, spk_audio_prompt, text, output_path,
291
  emo_audio_prompt=None, emo_alpha=1.0,
292
  emo_vector=None,
293
  use_emo_text=False, emo_text=None, use_random=False, interval_silence=200,
294
- verbose=False, max_text_tokens_per_sentence=120, **generation_kwargs):
295
- print(">> start inference...")
296
- self._set_gr_progress(0, "start inference...")
297
  if verbose:
298
- print(f"origin text:{text}, spk_audio_prompt:{spk_audio_prompt},"
299
- f" emo_audio_prompt:{emo_audio_prompt}, emo_alpha:{emo_alpha}, "
300
  f"emo_vector:{emo_vector}, use_emo_text:{use_emo_text}, "
301
  f"emo_text:{emo_text}")
302
  start_time = time.perf_counter()
303
 
304
- if use_emo_text:
 
 
305
  emo_audio_prompt = None
306
- emo_alpha = 1.0
307
- # assert emo_audio_prompt is None
308
- # assert emo_alpha == 1.0
309
  if emo_text is None:
310
- emo_text = text
311
- emo_dict, content = self.qwen_emo.inference(emo_text)
312
- print(emo_dict)
 
313
  emo_vector = list(emo_dict.values())
314
 
315
  if emo_vector is not None:
316
- emo_audio_prompt = None
317
- emo_alpha = 1.0
318
- # assert emo_audio_prompt is None
319
- # assert emo_alpha == 1.0
 
 
 
 
320
 
321
  if emo_audio_prompt is None:
 
 
322
  emo_audio_prompt = spk_audio_prompt
 
323
  emo_alpha = 1.0
324
- # assert emo_alpha == 1.0
325
 
326
  # 如果参考音频改变了,才需要重新生成, 提升速度
327
  if self.cache_spk_cond is None or self.cache_spk_audio_prompt != spk_audio_prompt:
328
- audio, sr = librosa.load(spk_audio_prompt)
329
- audio = torch.tensor(audio).unsqueeze(0)
330
  audio_22k = torchaudio.transforms.Resample(sr, 22050)(audio)
331
  audio_16k = torchaudio.transforms.Resample(sr, 16000)(audio)
332
 
@@ -377,7 +405,7 @@ class IndexTTS2:
377
  emovec_mat = emovec_mat.unsqueeze(0)
378
 
379
  if self.cache_emo_cond is None or self.cache_emo_audio_prompt != emo_audio_prompt:
380
- emo_audio, _ = librosa.load(emo_audio_prompt, sr=16000)
381
  emo_inputs = self.extract_features(emo_audio, sampling_rate=16000, return_tensors="pt")
382
  emo_input_features = emo_inputs["input_features"]
383
  emo_attention_mask = emo_inputs["attention_mask"]
@@ -392,12 +420,13 @@ class IndexTTS2:
392
 
393
  self._set_gr_progress(0.1, "text processing...")
394
  text_tokens_list = self.tokenizer.tokenize(text)
395
- sentences = self.tokenizer.split_sentences(text_tokens_list, max_text_tokens_per_sentence)
 
396
  if verbose:
397
  print("text_tokens_list:", text_tokens_list)
398
- print("sentences count:", len(sentences))
399
- print("max_text_tokens_per_sentence:", max_text_tokens_per_sentence)
400
- print(*sentences, sep="\n")
401
  do_sample = generation_kwargs.pop("do_sample", True)
402
  top_p = generation_kwargs.pop("top_p", 0.8)
403
  top_k = generation_kwargs.pop("top_k", 30)
@@ -414,9 +443,11 @@ class IndexTTS2:
414
  gpt_forward_time = 0
415
  s2mel_time = 0
416
  bigvgan_time = 0
417
- progress = 0
418
  has_warned = False
419
- for sent in sentences:
 
 
 
420
  text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
421
  text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
422
  if verbose:
@@ -424,7 +455,7 @@ class IndexTTS2:
424
  print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
425
  # debug tokenizer
426
  text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
427
- print("text_token_syms is same as sentence tokens", text_token_syms == sent)
428
 
429
  m_start_time = time.perf_counter()
430
  with torch.no_grad():
@@ -465,7 +496,7 @@ class IndexTTS2:
465
  warnings.warn(
466
  f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
467
  f"Input text tokens: {text_tokens.shape[1]}. "
468
- f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.",
469
  category=RuntimeWarning
470
  )
471
  has_warned = True
@@ -546,7 +577,8 @@ class IndexTTS2:
546
  # wavs.append(wav[:, :-512])
547
  wavs.append(wav.cpu()) # to cpu before saving
548
  end_time = time.perf_counter()
549
- self._set_gr_progress(0.9, "save audio...")
 
550
  wavs = self.insert_interval_silence(wavs, sampling_rate=sampling_rate, interval_silence=interval_silence)
551
  wav = torch.cat(wavs, dim=1)
552
  wav_length = wav.shape[-1] / sampling_rate
@@ -595,59 +627,52 @@ class QwenEmotion:
595
  device_map="auto"
596
  )
597
  self.prompt = "文本情感分类"
598
- self.convert_dict = {
599
- "愤怒": "angry",
600
  "高兴": "happy",
601
- "恐惧": "fear",
602
- "反感": "hate",
603
  "悲伤": "sad",
604
- "低落": "low",
605
- "惊讶": "surprise",
606
- "自然": "neutral",
607
  }
608
- self.backup_dict = {"happy": 0, "angry": 0, "sad": 0, "fear": 0, "hate": 0, "low": 0, "surprise": 0,
609
- "neutral": 1.0}
610
  self.max_score = 1.2
611
  self.min_score = 0.0
612
 
 
 
 
613
  def convert(self, content):
614
- content = content.replace("\n", " ")
615
- content = content.replace(" ", "")
616
- content = content.replace("{", "")
617
- content = content.replace("}", "")
618
- content = content.replace('"', "")
619
- parts = content.strip().split(',')
620
- print(parts)
621
- parts_dict = {}
622
- desired_order = ["高兴", "愤怒", "悲伤", "恐惧", "反感", "低落", "惊讶", "自然"]
623
- for part in parts:
624
- key_value = part.strip().split(':')
625
- if len(key_value) == 2:
626
- parts_dict[key_value[0].strip()] = part
627
- # 按照期望顺序重新排列
628
- ordered_parts = [parts_dict[key] for key in desired_order if key in parts_dict]
629
- parts = ordered_parts
630
- if len(parts) != len(self.convert_dict):
631
- return self.backup_dict
632
-
633
- emotion_dict = {}
634
- for part in parts:
635
- key_value = part.strip().split(':')
636
- if len(key_value) == 2:
637
- try:
638
- key = self.convert_dict[key_value[0].strip()]
639
- value = float(key_value[1].strip())
640
- value = max(self.min_score, min(self.max_score, value))
641
- emotion_dict[key] = value
642
- except Exception:
643
- continue
644
-
645
- for key in self.backup_dict:
646
- if key not in emotion_dict:
647
- emotion_dict[key] = 0.0
648
-
649
- if sum(emotion_dict.values()) <= 0:
650
- return self.backup_dict
651
 
652
  return emotion_dict
653
 
@@ -680,9 +705,30 @@ class QwenEmotion:
680
  except ValueError:
681
  index = 0
682
 
683
- content = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
684
- emotion_dict = self.convert(content)
685
- return emotion_dict, content
686
 
687
 
688
  if __name__ == "__main__":
 
1
  import os
2
  from subprocess import CalledProcessError
3
 
4
+ os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'
5
+ import json
6
+ import re
7
  import time
8
  import librosa
9
  import torch
 
37
 
38
  class IndexTTS2:
39
  def __init__(
40
+ self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, device=None,
41
  use_cuda_kernel=None,use_deepspeed=False
42
  ):
43
  """
44
  Args:
45
  cfg_path (str): path to the config file.
46
  model_dir (str): path to the model directory.
47
+ use_fp16 (bool): whether to use fp16.
48
  device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS.
49
  use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device.
50
+ use_deepspeed (bool): whether to use DeepSpeed or not.
51
  """
52
  if device is not None:
53
  self.device = device
54
+ self.use_fp16 = False if device == "cpu" else use_fp16
55
  self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda")
56
  elif torch.cuda.is_available():
57
  self.device = "cuda:0"
58
+ self.use_fp16 = use_fp16
59
  self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel
60
+ elif hasattr(torch, "xpu") and torch.xpu.is_available():
61
+ self.device = "xpu"
62
+ self.use_fp16 = use_fp16
63
+ self.use_cuda_kernel = False
64
  elif hasattr(torch, "mps") and torch.backends.mps.is_available():
65
  self.device = "mps"
66
+ self.use_fp16 = False # Use float16 on MPS is overhead than float32
67
  self.use_cuda_kernel = False
68
  else:
69
  self.device = "cpu"
70
+ self.use_fp16 = False
71
  self.use_cuda_kernel = False
72
  print(">> Be patient, it may take a while to run in CPU mode.")
73
 
74
  self.cfg = OmegaConf.load(cfg_path)
75
  self.model_dir = model_dir
76
+ self.dtype = torch.float16 if self.use_fp16 else None
77
  self.stop_mel_token = self.cfg.gpt.stop_mel_token
78
 
79
  self.qwen_emo = QwenEmotion(os.path.join(self.model_dir, self.cfg.qwen_emo_path))
 
82
  self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint)
83
  load_checkpoint(self.gpt, self.gpt_path)
84
  self.gpt = self.gpt.to(self.device)
85
+ if self.use_fp16:
86
  self.gpt.eval().half()
87
  else:
88
  self.gpt.eval()
89
  print(">> GPT weights restored from:", self.gpt_path)
90
+
91
+ if use_deepspeed:
92
  try:
93
  import deepspeed
 
94
  except (ImportError, OSError, CalledProcessError) as e:
95
  use_deepspeed = False
96
+ print(f">> Failed to load DeepSpeed. Falling back to normal inference. Error: {e}")
97
 
98
+ self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=self.use_fp16)
 
 
99
 
100
  if self.use_cuda_kernel:
101
  # preload the CUDA kernel for BigVGAN
102
  try:
103
+ from indextts.s2mel.modules.bigvgan.alias_free_activation.cuda import activation1d
104
 
105
+ print(">> Preload custom CUDA kernel for BigVGAN", activation1d.anti_alias_activation_cuda)
106
+ except Exception as e:
 
107
  print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.")
108
+ print(f"{e!r}")
109
  self.use_cuda_kernel = False
110
 
111
  self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
 
149
  print(">> campplus_model weights restored from:", campplus_ckpt_path)
150
 
151
  bigvgan_name = self.cfg.vocoder.name
152
+ self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=self.use_cuda_kernel)
153
  self.bigvgan = self.bigvgan.to(self.device)
154
  self.bigvgan.remove_weight_norm()
155
  self.bigvgan.eval()
 
267
 
268
  def insert_interval_silence(self, wavs, sampling_rate=22050, interval_silence=200):
269
  """
270
+ Insert silences between generated segments.
271
  wavs: List[torch.tensor]
272
  """
273
 
 
292
  if self.gr_progress is not None:
293
  self.gr_progress(value, desc=desc)
294
 
295
+ def _load_and_cut_audio(self,audio_path,max_audio_length_seconds,verbose=False,sr=None):
296
+ if not sr:
297
+ audio, sr = librosa.load(audio_path)
298
+ else:
299
+ audio, _ = librosa.load(audio_path,sr=sr)
300
+ audio = torch.tensor(audio).unsqueeze(0)
301
+ max_audio_samples = int(max_audio_length_seconds * sr)
302
+
303
+ if audio.shape[1] > max_audio_samples:
304
+ if verbose:
305
+ print(f"Audio too long ({audio.shape[1]} samples), truncating to {max_audio_samples} samples")
306
+ audio = audio[:, :max_audio_samples]
307
+ return audio, sr
308
+
309
  # 原始推理模式
310
  def infer(self, spk_audio_prompt, text, output_path,
311
  emo_audio_prompt=None, emo_alpha=1.0,
312
  emo_vector=None,
313
  use_emo_text=False, emo_text=None, use_random=False, interval_silence=200,
314
+ verbose=False, max_text_tokens_per_segment=120, **generation_kwargs):
315
+ print(">> starting inference...")
316
+ self._set_gr_progress(0, "starting inference...")
317
  if verbose:
318
+ print(f"origin text:{text}, spk_audio_prompt:{spk_audio_prompt}, "
319
+ f"emo_audio_prompt:{emo_audio_prompt}, emo_alpha:{emo_alpha}, "
320
  f"emo_vector:{emo_vector}, use_emo_text:{use_emo_text}, "
321
  f"emo_text:{emo_text}")
322
  start_time = time.perf_counter()
323
 
324
+ if use_emo_text or emo_vector is not None:
325
+ # we're using a text or emotion vector guidance; so we must remove
326
+ # "emotion reference voice", to ensure we use correct emotion mixing!
327
  emo_audio_prompt = None
328
+
329
+ if use_emo_text:
330
+ # automatically generate emotion vectors from text prompt
331
  if emo_text is None:
332
+ emo_text = text # use main text prompt
333
+ emo_dict = self.qwen_emo.inference(emo_text)
334
+ print(f"detected emotion vectors from text: {emo_dict}")
335
+ # convert ordered dict to list of vectors; the order is VERY important!
336
  emo_vector = list(emo_dict.values())
337
 
338
  if emo_vector is not None:
339
+ # we have emotion vectors; they can't be blended via alpha mixing
340
+ # in the main inference process later, so we must pre-calculate
341
+ # their new strengths here based on the alpha instead!
342
+ emo_vector_scale = max(0.0, min(1.0, emo_alpha))
343
+ if emo_vector_scale != 1.0:
344
+ # scale each vector and truncate to 4 decimals (for nicer printing)
345
+ emo_vector = [int(x * emo_vector_scale * 10000) / 10000 for x in emo_vector]
346
+ print(f"scaled emotion vectors to {emo_vector_scale}x: {emo_vector}")
347
 
348
  if emo_audio_prompt is None:
349
+ # we are not using any external "emotion reference voice"; use
350
+ # speaker's voice as the main emotion reference audio.
351
  emo_audio_prompt = spk_audio_prompt
352
+ # must always use alpha=1.0 when we don't have an external reference voice
353
  emo_alpha = 1.0
 
354
 
355
  # 如果参考音频改变了,才需要重新生成, 提升速度
356
  if self.cache_spk_cond is None or self.cache_spk_audio_prompt != spk_audio_prompt:
357
+ audio,sr = self._load_and_cut_audio(spk_audio_prompt,15,verbose)
 
358
  audio_22k = torchaudio.transforms.Resample(sr, 22050)(audio)
359
  audio_16k = torchaudio.transforms.Resample(sr, 16000)(audio)
360
 
 
405
  emovec_mat = emovec_mat.unsqueeze(0)
406
 
407
  if self.cache_emo_cond is None or self.cache_emo_audio_prompt != emo_audio_prompt:
408
+ emo_audio, _ = self._load_and_cut_audio(emo_audio_prompt,15,verbose,sr=16000)
409
  emo_inputs = self.extract_features(emo_audio, sampling_rate=16000, return_tensors="pt")
410
  emo_input_features = emo_inputs["input_features"]
411
  emo_attention_mask = emo_inputs["attention_mask"]
 
420
 
421
  self._set_gr_progress(0.1, "text processing...")
422
  text_tokens_list = self.tokenizer.tokenize(text)
423
+ segments = self.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment)
424
+ segments_count = len(segments)
425
  if verbose:
426
  print("text_tokens_list:", text_tokens_list)
427
+ print("segments count:", segments_count)
428
+ print("max_text_tokens_per_segment:", max_text_tokens_per_segment)
429
+ print(*segments, sep="\n")
430
  do_sample = generation_kwargs.pop("do_sample", True)
431
  top_p = generation_kwargs.pop("top_p", 0.8)
432
  top_k = generation_kwargs.pop("top_k", 30)
 
443
  gpt_forward_time = 0
444
  s2mel_time = 0
445
  bigvgan_time = 0
 
446
  has_warned = False
447
+ for seg_idx, sent in enumerate(segments):
448
+ self._set_gr_progress(0.2 + 0.7 * seg_idx / segments_count,
449
+ f"speech synthesis {seg_idx + 1}/{segments_count}...")
450
+
451
  text_tokens = self.tokenizer.convert_tokens_to_ids(sent)
452
  text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0)
453
  if verbose:
 
455
  print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}")
456
  # debug tokenizer
457
  text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist())
458
+ print("text_token_syms is same as segment tokens", text_token_syms == sent)
459
 
460
  m_start_time = time.perf_counter()
461
  with torch.no_grad():
 
496
  warnings.warn(
497
  f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). "
498
  f"Input text tokens: {text_tokens.shape[1]}. "
499
+ f"Consider reducing `max_text_tokens_per_segment`({max_text_tokens_per_segment}) or increasing `max_mel_tokens`.",
500
  category=RuntimeWarning
501
  )
502
  has_warned = True
 
577
  # wavs.append(wav[:, :-512])
578
  wavs.append(wav.cpu()) # to cpu before saving
579
  end_time = time.perf_counter()
580
+
581
+ self._set_gr_progress(0.9, "saving audio...")
582
  wavs = self.insert_interval_silence(wavs, sampling_rate=sampling_rate, interval_silence=interval_silence)
583
  wav = torch.cat(wavs, dim=1)
584
  wav_length = wav.shape[-1] / sampling_rate
 
627
  device_map="auto"
628
  )
629
  self.prompt = "文本情感分类"
630
+ self.cn_key_to_en = {
 
631
  "高兴": "happy",
632
+ "愤怒": "angry",
 
633
  "悲伤": "sad",
634
+ "恐惧": "afraid",
635
+ "反感": "disgusted",
636
+ # TODO: the "低落" (melancholic) emotion will always be mapped to
637
+ # "悲伤" (sad) by QwenEmotion's text analysis. it doesn't know the
638
+ # difference between those emotions even if user writes exact words.
639
+ # SEE: `self.melancholic_words` for current workaround.
640
+ "低落": "melancholic",
641
+ "惊讶": "surprised",
642
+ "自然": "calm",
643
+ }
644
+ self.desired_vector_order = ["高兴", "愤怒", "悲伤", "恐惧", "反感", "低落", "惊讶", "自然"]
645
+ self.melancholic_words = {
646
+ # emotion text phrases that will force QwenEmotion's "悲伤" (sad) detection
647
+ # to become "低落" (melancholic) instead, to fix limitations mentioned above.
648
+ "低落",
649
+ "melancholy",
650
+ "melancholic",
651
+ "depression",
652
+ "depressed",
653
+ "gloomy",
654
  }
 
 
655
  self.max_score = 1.2
656
  self.min_score = 0.0
657
 
658
+ def clamp_score(self, value):
659
+ return max(self.min_score, min(self.max_score, value))
660
+
661
  def convert(self, content):
662
+ # generate emotion vector dictionary:
663
+ # - insert values in desired order (Python 3.7+ `dict` remembers insertion order)
664
+ # - convert Chinese keys to English
665
+ # - clamp all values to the allowed min/max range
666
+ # - use 0.0 for any values that were missing in `content`
667
+ emotion_dict = {
668
+ self.cn_key_to_en[cn_key]: self.clamp_score(content.get(cn_key, 0.0))
669
+ for cn_key in self.desired_vector_order
670
+ }
671
+
672
+ # default to a calm/neutral voice if all emotion vectors were empty
673
+ if all(val <= 0.0 for val in emotion_dict.values()):
674
+ print(">> no emotions detected; using default calm/neutral voice")
675
+ emotion_dict["calm"] = 1.0
676
 
677
  return emotion_dict
678
 
 
705
  except ValueError:
706
  index = 0
707
 
708
+ content = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True)
709
+
710
+ # decode the JSON emotion detections as a dictionary
711
+ try:
712
+ content = json.loads(content)
713
+ except json.decoder.JSONDecodeError:
714
+ # invalid JSON; fallback to manual string parsing
715
+ # print(">> parsing QwenEmotion response", content)
716
+ content = {
717
+ m.group(1): float(m.group(2))
718
+ for m in re.finditer(r'([^\s":.,]+?)"?\s*:\s*([\d.]+)', content)
719
+ }
720
+ # print(">> dict result", content)
721
+
722
+ # workaround for QwenEmotion's inability to distinguish "悲伤" (sad) vs "低落" (melancholic).
723
+ # if we detect any of the IndexTTS "melancholic" words, we swap those vectors
724
+ # to encode the "sad" emotion as "melancholic" (instead of sadness).
725
+ text_input_lower = text_input.lower()
726
+ if any(word in text_input_lower for word in self.melancholic_words):
727
+ # print(">> before vec swap", content)
728
+ content["悲伤"], content["低落"] = content.get("低落", 0.0), content.get("悲伤", 0.0)
729
+ # print(">> after vec swap", content)
730
+
731
+ return self.convert(content)
732
 
733
 
734
  if __name__ == "__main__":
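infer_v2.py now parses the QwenEmotion output as JSON (with a regex fallback), scales manually supplied emotion vectors by emo_alpha, truncates over-long reference audio via _load_and_cut_audio, and rebuilds the emotion dictionary in a fixed key order. A condensed illustration of that convert step, with the constants taken from the diff; this is a sketch, not the class itself:

    CN_KEY_TO_EN = {
        "高兴": "happy", "愤怒": "angry", "悲伤": "sad", "恐惧": "afraid",
        "反感": "disgusted", "低落": "melancholic", "惊讶": "surprised", "自然": "calm",
    }
    DESIRED_ORDER = ["高兴", "愤怒", "悲伤", "恐惧", "反感", "低落", "惊讶", "自然"]
    MIN_SCORE, MAX_SCORE = 0.0, 1.2

    def convert(content: dict) -> dict:
        # Fixed ordering matters: the values are later consumed positionally as a vector.
        emotion_dict = {
            CN_KEY_TO_EN[k]: max(MIN_SCORE, min(MAX_SCORE, float(content.get(k, 0.0))))
            for k in DESIRED_ORDER
        }
        # Fall back to a calm/neutral voice when nothing was detected.
        if all(v <= 0.0 for v in emotion_dict.values()):
            emotion_dict["calm"] = 1.0
        return emotion_dict

    print(convert({"悲伤": 1.5, "高兴": 0.2}))  # sad is clamped to 1.2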
indextts/s2mel/modules/openvoice/api.py CHANGED
@@ -63,9 +63,9 @@ class BaseSpeakerTTS(OpenVoiceBaseClass):
         return audio_segments
 
     @staticmethod
-    def split_sentences_into_pieces(text, language_str):
-        texts = utils.split_sentence(text, language_str=language_str)
-        print(" > Text splitted to sentences.")
+    def split_segments_into_pieces(text, language_str):
+        texts = utils.split_segment(text, language_str=language_str)
+        print(" > Text split into segments.")
         print('\n'.join(texts))
         print(" > ===========================")
         return texts
@@ -74,7 +74,7 @@ class BaseSpeakerTTS(OpenVoiceBaseClass):
         mark = self.language_marks.get(language.lower(), None)
         assert mark is not None, f"language {language} is not supported"
 
-        texts = self.split_sentences_into_pieces(text, mark)
+        texts = self.split_segments_into_pieces(text, mark)
 
         audio_list = []
         for t in texts:
indextts/s2mel/modules/openvoice/openvoice_app.py CHANGED
@@ -233,7 +233,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
233
  with gr.Column():
234
  input_text_gr = gr.Textbox(
235
  label="Text Prompt",
236
- info="One or two sentences at a time is better. Up to 200 text characters.",
237
  value="He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.",
238
  )
239
  style_gr = gr.Dropdown(
 
233
  with gr.Column():
234
  input_text_gr = gr.Textbox(
235
  label="Text Prompt",
236
+ info="One or two sentences at a time produces the best results. Up to 200 text characters.",
237
  value="He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.",
238
  )
239
  style_gr = gr.Dropdown(
indextts/s2mel/modules/openvoice/utils.py CHANGED
@@ -75,23 +75,23 @@ def bits_to_string(bits_array):
75
  return output_string
76
 
77
 
78
- def split_sentence(text, min_len=10, language_str='[EN]'):
79
  if language_str in ['EN']:
80
- sentences = split_sentences_latin(text, min_len=min_len)
81
  else:
82
- sentences = split_sentences_zh(text, min_len=min_len)
83
- return sentences
84
 
85
- def split_sentences_latin(text, min_len=10):
86
- """Split Long sentences into list of short ones
87
 
88
  Args:
89
  str: Input sentences.
90
 
91
  Returns:
92
- List[str]: list of output sentences.
93
  """
94
- # deal with dirty sentences
95
  text = re.sub('[。!?;]', '.', text)
96
  text = re.sub('[,]', ',', text)
97
  text = re.sub('[“”]', '"', text)
@@ -100,36 +100,36 @@ def split_sentences_latin(text, min_len=10):
100
  text = re.sub('[\n\t ]+', ' ', text)
101
  text = re.sub('([,.!?;])', r'\1 $#!', text)
102
  # split
103
- sentences = [s.strip() for s in text.split('$#!')]
104
- if len(sentences[-1]) == 0: del sentences[-1]
105
 
106
- new_sentences = []
107
  new_sent = []
108
  count_len = 0
109
- for ind, sent in enumerate(sentences):
110
  # print(sent)
111
  new_sent.append(sent)
112
  count_len += len(sent.split(" "))
113
- if count_len > min_len or ind == len(sentences) - 1:
114
  count_len = 0
115
- new_sentences.append(' '.join(new_sent))
116
  new_sent = []
117
- return merge_short_sentences_latin(new_sentences)
118
 
119
 
120
- def merge_short_sentences_latin(sens):
121
- """Avoid short sentences by merging them with the following sentence.
122
 
123
  Args:
124
- List[str]: list of input sentences.
125
 
126
  Returns:
127
- List[str]: list of output sentences.
128
  """
129
  sens_out = []
130
  for s in sens:
131
- # If the previous sentence is too short, merge them with
132
- # the current sentence.
133
  if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
134
  sens_out[-1] = sens_out[-1] + " " + s
135
  else:
@@ -142,7 +142,7 @@ def merge_short_sentences_latin(sens):
142
  pass
143
  return sens_out
144
 
145
- def split_sentences_zh(text, min_len=10):
146
  text = re.sub('[。!?;]', '.', text)
147
  text = re.sub('[,]', ',', text)
148
  # 将文本中的换行符、空格和制表符替换为空格
@@ -150,37 +150,37 @@ def split_sentences_zh(text, min_len=10):
150
  # 在标点符号后添加一个空格
151
  text = re.sub('([,.!?;])', r'\1 $#!', text)
152
  # 分隔句子并去除前后空格
153
- # sentences = [s.strip() for s in re.split('(。|!|?|;)', text)]
154
- sentences = [s.strip() for s in text.split('$#!')]
155
- if len(sentences[-1]) == 0: del sentences[-1]
156
 
157
- new_sentences = []
158
  new_sent = []
159
  count_len = 0
160
- for ind, sent in enumerate(sentences):
161
  new_sent.append(sent)
162
  count_len += len(sent)
163
- if count_len > min_len or ind == len(sentences) - 1:
164
  count_len = 0
165
- new_sentences.append(' '.join(new_sent))
166
  new_sent = []
167
- return merge_short_sentences_zh(new_sentences)
168
 
169
 
170
- def merge_short_sentences_zh(sens):
171
  # return sens
172
- """Avoid short sentences by merging them with the following sentence.
173
 
174
  Args:
175
- List[str]: list of input sentences.
176
 
177
  Returns:
178
- List[str]: list of output sentences.
179
  """
180
  sens_out = []
181
  for s in sens:
182
  # If the previous sentense is too short, merge them with
183
- # the current sentence.
184
  if len(sens_out) > 0 and len(sens_out[-1]) <= 2:
185
  sens_out[-1] = sens_out[-1] + " " + s
186
  else:
 
75
  return output_string
76
 
77
 
78
+ def split_segment(text, min_len=10, language_str='[EN]'):
79
  if language_str in ['EN']:
80
+ segments = split_segments_latin(text, min_len=min_len)
81
  else:
82
+ segments = split_segments_zh(text, min_len=min_len)
83
+ return segments
84
 
85
+ def split_segments_latin(text, min_len=10):
86
+ """Split Long sentences into list of short segments.
87
 
88
  Args:
89
  str: Input sentences.
90
 
91
  Returns:
92
+ List[str]: list of output segments.
93
  """
94
+ # deal with dirty text characters
95
  text = re.sub('[。!?;]', '.', text)
96
  text = re.sub('[,]', ',', text)
97
  text = re.sub('[“”]', '"', text)
 
100
  text = re.sub('[\n\t ]+', ' ', text)
101
  text = re.sub('([,.!?;])', r'\1 $#!', text)
102
  # split
103
+ segments = [s.strip() for s in text.split('$#!')]
104
+ if len(segments[-1]) == 0: del segments[-1]
105
 
106
+ new_segments = []
107
  new_sent = []
108
  count_len = 0
109
+ for ind, sent in enumerate(segments):
110
  # print(sent)
111
  new_sent.append(sent)
112
  count_len += len(sent.split(" "))
113
+ if count_len > min_len or ind == len(segments) - 1:
114
  count_len = 0
115
+ new_segments.append(' '.join(new_sent))
116
  new_sent = []
117
+ return merge_short_segments_latin(new_segments)
118
 
119
 
120
+ def merge_short_segments_latin(sens):
121
+ """Avoid short segments by merging them with the following segment.
122
 
123
  Args:
124
+ List[str]: list of input segments.
125
 
126
  Returns:
127
+ List[str]: list of output segments.
128
  """
129
  sens_out = []
130
  for s in sens:
131
+ # If the previous segment is too short, merge them with
132
+ # the current segment.
133
  if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
134
  sens_out[-1] = sens_out[-1] + " " + s
135
  else:
 
142
  pass
143
  return sens_out
144
 
145
+ def split_segments_zh(text, min_len=10):
146
  text = re.sub('[。!?;]', '.', text)
147
  text = re.sub('[,]', ',', text)
148
  # 将文本中的换行符、空格和制表符替换为空格
 
150
  # 在标点符号后添加一个空格
151
  text = re.sub('([,.!?;])', r'\1 $#!', text)
152
  # 分隔句子并去除前后空格
153
+ # segments = [s.strip() for s in re.split('(。|!|?|;)', text)]
154
+ segments = [s.strip() for s in text.split('$#!')]
155
+ if len(segments[-1]) == 0: del segments[-1]
156
 
157
+ new_segments = []
158
  new_sent = []
159
  count_len = 0
160
+ for ind, sent in enumerate(segments):
161
  new_sent.append(sent)
162
  count_len += len(sent)
163
+ if count_len > min_len or ind == len(segments) - 1:
164
  count_len = 0
165
+ new_segments.append(' '.join(new_sent))
166
  new_sent = []
167
+ return merge_short_segments_zh(new_segments)
168
 
169
 
170
+ def merge_short_segments_zh(sens):
171
  # return sens
172
+ """Avoid short segments by merging them with the following segment.
173
 
174
  Args:
175
+ List[str]: list of input segments.
176
 
177
  Returns:
178
+ List[str]: list of output segments.
179
  """
180
  sens_out = []
181
  for s in sens:
182
  # If the previous sentense is too short, merge them with
183
+ # the current segment.
184
  if len(sens_out) > 0 and len(sens_out[-1]) <= 2:
185
  sens_out[-1] = sens_out[-1] + " " + s
186
  else:
indextts/utils/front.py CHANGED
@@ -91,7 +91,7 @@ class TextNormalizer:
91
  import platform
92
  if self.zh_normalizer is not None and self.en_normalizer is not None:
93
  return
94
- if platform.system() == "Darwin":
95
  from wetext import Normalizer
96
 
97
  self.zh_normalizer = Normalizer(remove_erhua=False, lang="zh", operator="tn")
@@ -342,8 +342,8 @@ class TextTokenizer:
342
  return de_tokenized_by_CJK_char(decoded, do_lower_case=do_lower_case)
343
 
344
  @staticmethod
345
- def split_sentences_by_token(
346
- tokenized_str: List[str], split_tokens: List[str], max_tokens_per_sentence: int
347
  ) -> List[List[str]]:
348
  """
349
  将tokenize后的结果按特定token进一步分割
@@ -351,67 +351,67 @@ class TextTokenizer:
351
  # 处理特殊情况
352
  if len(tokenized_str) == 0:
353
  return []
354
- sentences: List[List[str]] = []
355
- current_sentence = []
356
- current_sentence_tokens_len = 0
357
  for i in range(len(tokenized_str)):
358
  token = tokenized_str[i]
359
- current_sentence.append(token)
360
- current_sentence_tokens_len += 1
361
- if current_sentence_tokens_len <= max_tokens_per_sentence:
362
- if token in split_tokens and current_sentence_tokens_len > 2:
363
  if i < len(tokenized_str) - 1:
364
  if tokenized_str[i + 1] in ["'", "▁'"]:
365
  # 后续token是',则不切分
366
- current_sentence.append(tokenized_str[i + 1])
367
  i += 1
368
- sentences.append(current_sentence)
369
- current_sentence = []
370
- current_sentence_tokens_len = 0
371
  continue
372
  # 如果当前tokens的长度超过最大限制
373
- if not ("," in split_tokens or "▁," in split_tokens ) and ("," in current_sentence or "▁," in current_sentence):
374
  # 如果当前tokens中有,,则按,分割
375
- sub_sentences = TextTokenizer.split_sentences_by_token(
376
- current_sentence, [",", "▁,"], max_tokens_per_sentence=max_tokens_per_sentence
377
  )
378
- elif "-" not in split_tokens and "-" in current_sentence:
379
  # 没有,,则按-分割
380
- sub_sentences = TextTokenizer.split_sentences_by_token(
381
- current_sentence, ["-"], max_tokens_per_sentence=max_tokens_per_sentence
382
  )
383
  else:
384
  # 按照长度分割
385
- sub_sentences = []
386
- for j in range(0, len(current_sentence), max_tokens_per_sentence):
387
- if j + max_tokens_per_sentence < len(current_sentence):
388
- sub_sentences.append(current_sentence[j : j + max_tokens_per_sentence])
389
  else:
390
- sub_sentences.append(current_sentence[j:])
391
  warnings.warn(
392
- f"The tokens length of sentence exceeds limit: {max_tokens_per_sentence}, "
393
- f"Tokens in sentence: {current_sentence}."
394
  "Maybe unexpected behavior",
395
  RuntimeWarning,
396
  )
397
- sentences.extend(sub_sentences)
398
- current_sentence = []
399
- current_sentence_tokens_len = 0
400
- if current_sentence_tokens_len > 0:
401
- assert current_sentence_tokens_len <= max_tokens_per_sentence
402
- sentences.append(current_sentence)
403
  # 如果相邻的句子加起来长度小于最大限制,则合并
404
- merged_sentences = []
405
- for sentence in sentences:
406
- if len(sentence) == 0:
407
  continue
408
- if len(merged_sentences) == 0:
409
- merged_sentences.append(sentence)
410
- elif len(merged_sentences[-1]) + len(sentence) <= max_tokens_per_sentence:
411
- merged_sentences[-1] = merged_sentences[-1] + sentence
412
  else:
413
- merged_sentences.append(sentence)
414
- return merged_sentences
415
 
416
  punctuation_marks_tokens = [
417
  ".",
@@ -422,9 +422,9 @@ class TextTokenizer:
422
  "▁?",
423
  "▁...", # ellipsis
424
  ]
425
- def split_sentences(self, tokenized: List[str], max_tokens_per_sentence=120) -> List[List[str]]:
426
- return TextTokenizer.split_sentences_by_token(
427
- tokenized, self.punctuation_marks_tokens, max_tokens_per_sentence=max_tokens_per_sentence
428
  )
429
 
430
 
@@ -516,19 +516,19 @@ if __name__ == "__main__":
516
  # 测试 normalize后的字符能被分词器识别
517
  print(f"`{ch}`", "->", tokenizer.sp_model.Encode(ch, out_type=str))
518
  print(f"` {ch}`", "->", tokenizer.sp_model.Encode(f" {ch}", out_type=str))
519
- max_tokens_per_sentence=120
520
  for i in range(len(cases)):
521
  print(f"原始文本: {cases[i]}")
522
  print(f"Normalized: {text_normalizer.normalize(cases[i])}")
523
  tokens = tokenizer.tokenize(cases[i])
524
  print("Tokenzied: ", ", ".join([f"`{t}`" for t in tokens]))
525
- sentences = tokenizer.split_sentences(tokens, max_tokens_per_sentence=max_tokens_per_sentence)
526
- print("Splitted sentences count:", len(sentences))
527
- if len(sentences) > 1:
528
- for j in range(len(sentences)):
529
- print(f" {j}, count:", len(sentences[j]), ", tokens:", "".join(sentences[j]))
530
- if len(sentences[j]) > max_tokens_per_sentence:
531
- print(f"Warning: sentence {j} is too long, length: {len(sentences[j])}")
532
  #print(f"Token IDs (first 10): {codes[i][:10]}")
533
  if tokenizer.unk_token in codes[i]:
534
  print(f"Warning: `{cases[i]}` contains UNKNOWN token")
 
91
  import platform
92
  if self.zh_normalizer is not None and self.en_normalizer is not None:
93
  return
94
+ if platform.system() != "Linux": # Mac and Windows
95
  from wetext import Normalizer
96
 
97
  self.zh_normalizer = Normalizer(remove_erhua=False, lang="zh", operator="tn")
 
342
  return de_tokenized_by_CJK_char(decoded, do_lower_case=do_lower_case)
343
 
344
  @staticmethod
345
+ def split_segments_by_token(
346
+ tokenized_str: List[str], split_tokens: List[str], max_text_tokens_per_segment: int
347
  ) -> List[List[str]]:
348
  """
349
  将tokenize后的结果按特定token进一步分割
 
351
  # 处理特殊情况
352
  if len(tokenized_str) == 0:
353
  return []
354
+ segments: List[List[str]] = []
355
+ current_segment = []
356
+ current_segment_tokens_len = 0
357
  for i in range(len(tokenized_str)):
358
  token = tokenized_str[i]
359
+ current_segment.append(token)
360
+ current_segment_tokens_len += 1
361
+ if current_segment_tokens_len <= max_text_tokens_per_segment:
362
+ if token in split_tokens and current_segment_tokens_len > 2:
363
  if i < len(tokenized_str) - 1:
364
  if tokenized_str[i + 1] in ["'", "▁'"]:
365
  # 后续token是',则不切分
366
+ current_segment.append(tokenized_str[i + 1])
367
  i += 1
368
+ segments.append(current_segment)
369
+ current_segment = []
370
+ current_segment_tokens_len = 0
371
  continue
372
  # 如果当前tokens的长度超过最大限制
373
+ if not ("," in split_tokens or "▁," in split_tokens ) and ("," in current_segment or "▁," in current_segment):
374
  # 如果当前tokens中有,,则按,分割
375
+ sub_segments = TextTokenizer.split_segments_by_token(
376
+ current_segment, [",", "▁,"], max_text_tokens_per_segment=max_text_tokens_per_segment
377
  )
378
+ elif "-" not in split_tokens and "-" in current_segment:
379
  # 没有,,则按-分割
380
+ sub_segments = TextTokenizer.split_segments_by_token(
381
+ current_segment, ["-"], max_text_tokens_per_segment=max_text_tokens_per_segment
382
  )
383
  else:
384
  # 按照长度分割
385
+ sub_segments = []
386
+ for j in range(0, len(current_segment), max_text_tokens_per_segment):
387
+ if j + max_text_tokens_per_segment < len(current_segment):
388
+ sub_segments.append(current_segment[j : j + max_text_tokens_per_segment])
389
  else:
390
+ sub_segments.append(current_segment[j:])
391
  warnings.warn(
392
+ f"The tokens length of segment exceeds limit: {max_text_tokens_per_segment}, "
393
+ f"Tokens in segment: {current_segment}."
394
  "Maybe unexpected behavior",
395
  RuntimeWarning,
396
  )
397
+ segments.extend(sub_segments)
398
+ current_segment = []
399
+ current_segment_tokens_len = 0
400
+ if current_segment_tokens_len > 0:
401
+ assert current_segment_tokens_len <= max_text_tokens_per_segment
402
+ segments.append(current_segment)
403
  # 如果相邻的句子加起来长度小于最大限制,则合并
404
+ merged_segments = []
405
+ for segment in segments:
406
+ if len(segment) == 0:
407
  continue
408
+ if len(merged_segments) == 0:
409
+ merged_segments.append(segment)
410
+ elif len(merged_segments[-1]) + len(segment) <= max_text_tokens_per_segment:
411
+ merged_segments[-1] = merged_segments[-1] + segment
412
  else:
413
+ merged_segments.append(segment)
414
+ return merged_segments
415
 
416
  punctuation_marks_tokens = [
417
  ".",
 
422
  "▁?",
423
  "▁...", # ellipsis
424
  ]
425
+ def split_segments(self, tokenized: List[str], max_text_tokens_per_segment=120) -> List[List[str]]:
426
+ return TextTokenizer.split_segments_by_token(
427
+ tokenized, self.punctuation_marks_tokens, max_text_tokens_per_segment=max_text_tokens_per_segment
428
  )
429
 
430
 
 
516
  # 测试 normalize后的字符能被分词器识别
517
  print(f"`{ch}`", "->", tokenizer.sp_model.Encode(ch, out_type=str))
518
  print(f"` {ch}`", "->", tokenizer.sp_model.Encode(f" {ch}", out_type=str))
519
+ max_text_tokens_per_segment=120
520
  for i in range(len(cases)):
521
  print(f"原始文本: {cases[i]}")
522
  print(f"Normalized: {text_normalizer.normalize(cases[i])}")
523
  tokens = tokenizer.tokenize(cases[i])
524
  print("Tokenzied: ", ", ".join([f"`{t}`" for t in tokens]))
525
+ segments = tokenizer.split_segments(tokens, max_text_tokens_per_segment=max_text_tokens_per_segment)
526
+ print("Segments count:", len(segments))
527
+ if len(segments) > 1:
528
+ for j in range(len(segments)):
529
+ print(f" {j}, count:", len(segments[j]), ", tokens:", "".join(segments[j]))
530
+ if len(segments[j]) > max_text_tokens_per_segment:
531
+ print(f"Warning: segment {j} is too long, length: {len(segments[j])}")
532
  #print(f"Token IDs (first 10): {codes[i][:10]}")
533
  if tokenizer.unk_token in codes[i]:
534
  print(f"Warning: `{cases[i]}` contains UNKNOWN token")
indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169
3
+ size 21029926
tools/i18n/locale/en_US.json CHANGED
@@ -1,46 +1,49 @@
1
  {
2
- "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.": "This software is open-sourced under the MIT License. The author has no control over the software, and users of the software, as well as those who distribute the audio generated by the software, assume full responsibility.",
3
- "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "If you do not agree to these terms, you are not permitted to use or reference any code or files within the software package. For further details, please refer to the LICENSE file in the root directory.",
4
  "时长必须为正数": "Duration must be a positive number",
5
  "请输入有效的浮点数": "Please enter a valid floating-point number",
6
  "使用情感参考音频": "Use emotion reference audio",
7
- "使用情感向量控制": "Use emotion vector",
8
  "使用情感描述文本控制": "Use text description to control emotion",
9
  "上传情感参考音频": "Upload emotion reference audio",
10
  "情感权重": "Emotion control weight",
11
  "喜": "Happy",
12
  "怒": "Angry",
13
  "哀": "Sad",
14
- "惧": "Fear",
15
- "厌恶": "Hate",
16
- "低落": "Low",
17
- "惊喜": "Surprise",
18
- "平静": "Neutral",
19
  "情感描述文本": "Emotion description",
20
- "请输入情感描述文本": "Please input emotion description",
21
  "高级生成参数设置": "Advanced generation parameter settings",
22
  "情感向量之和不能超过1.5,请调整后重试。": "The sum of the emotion vectors cannot exceed 1.5. Please adjust and try again.",
23
- "音色参考音频": "Voice reference",
24
  "音频生成": "Speech Synthesis",
25
  "文本": "Text",
26
  "生成语音": "Synthesize",
27
  "生成结果": "Synthesis Result",
28
  "功能设置": "Settings",
29
- "分句设置": "Sentence segmentation settings",
30
- "参数会影响音频质量和生成速度": "Parameters below affect audio quality and generation speed",
31
- "分句最大Token数": "Max tokens per sentence",
32
- "建议80~200之间,值越大,分句越长;值越小,分句越碎;过小过大都可能导致音频质量不高": "Recommended between 80 and 200. The larger the value, the longer the sentences; the smaller the value, the more fragmented the sentences. Values that are too small or too large may lead to poor audio quality.",
33
- "预览分句结果": "Preview sentence segmentation result",
34
  "序号": "Index",
35
  "分句内容": "Content",
36
  "Token数": "Token Count",
37
  "情感控制方式": "Emotion control method",
38
  "GPT2 采样设置": "GPT-2 Sampling Configuration",
39
- "参数会影响音频多样性和生成速度详见": "Influence both the diversity of the generated audio and the generation speed. For further details, refer to",
40
- "请上传情感参考音频": "Please upload emotion reference audio",
41
- "当前模型版本": "Current model version ",
42
- "请输入目标文本": "Please input text to synthesize",
43
- "例如:高兴,愤怒,悲伤等": "e.g., happy, angry, sad, etc.",
 
 
44
  "与音色参考音频相同": "Same as the voice reference",
45
- "情感随机采样": "Random emotion sampling"
 
46
  }
 
1
  {
2
+ "本软件以自拟协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.": "This software is open-sourced under customized license. The author has no control over the software, and users of the software, as well as those who distribute the audio generated by the software, assume full responsibility.",
3
+ "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "If you do not agree to these terms, you are not permitted to use or reference any code or files within the software package. For further details, please refer to the LICENSE files in the root directory.",
4
  "时长必须为正数": "Duration must be a positive number",
5
  "请输入有效的浮点数": "Please enter a valid floating-point number",
6
  "使用情感参考音频": "Use emotion reference audio",
7
+ "使用情感向量控制": "Use emotion vectors",
8
  "使用情感描述文本控制": "Use text description to control emotion",
9
  "上传情感参考音频": "Upload emotion reference audio",
10
  "情感权重": "Emotion control weight",
11
  "喜": "Happy",
12
  "怒": "Angry",
13
  "哀": "Sad",
14
+ "惧": "Afraid",
15
+ "厌恶": "Disgusted",
16
+ "低落": "Melancholic",
17
+ "惊喜": "Surprised",
18
+ "平静": "Calm",
19
  "情感描述文本": "Emotion description",
20
+ "请输入情绪描述(或留空以自动使用目标文本作为情绪描述)": "Please input an emotion description (or leave blank to automatically use the main text prompt)",
21
  "高级生成参数设置": "Advanced generation parameter settings",
22
  "情感向量之和不能超过1.5,请调整后重试。": "The sum of the emotion vectors cannot exceed 1.5. Please adjust and try again.",
23
+ "音色参考音频": "Voice Reference",
24
  "音频生成": "Speech Synthesis",
25
  "文本": "Text",
26
  "生成语音": "Synthesize",
27
  "生成结果": "Synthesis Result",
28
  "功能设置": "Settings",
29
+ "分句设置": "Text segmentation settings",
30
+ "参数会影响音频质量和生成速度": "These parameters affect the audio quality and generation speed.",
31
+ "分句最大Token数": "Max tokens per generation segment",
32
+ "建议80~200之间,值越大,分句越长;值越小,分句越碎;过小过大都可能导致音频质量不高": "Recommended range: 80 - 200. Larger values require more VRAM but improves the flow of the speech, while lower values require less VRAM but means more fragmented sentences. Values that are too small or too large may lead to less coherent speech.",
33
+ "预览分句结果": "Preview of the audio generation segments",
34
  "序号": "Index",
35
  "分句内容": "Content",
36
  "Token数": "Token Count",
37
  "情感控制方式": "Emotion control method",
38
  "GPT2 采样设置": "GPT-2 Sampling Configuration",
39
+ "参数会影响音频多样性和生成速度详见": "Influences both the diversity of the generated audio and the generation speed. For further details, refer to",
40
+ "是否进行采样": "Enable GPT-2 sampling",
41
+ "生成Token最大数量,���小导致音频被截断": "Maximum number of tokens to generate. If text exceeds this, the audio will be cut off.",
42
+ "请上传情感参考音频": "Please upload the emotion reference audio",
43
+ "当前模型版本": "Current model version: ",
44
+ "请输入目标文本": "Please input the text to synthesize",
45
+ "例如:委屈巴巴、危险在悄悄逼近": "e.g. deeply sad, danger is creeping closer",
46
  "与音色参考音频相同": "Same as the voice reference",
47
+ "情感随机采样": "Randomize emotion sampling",
48
+ "显示实验功能": "Show experimental features"
49
  }
tools/i18n/locale/zh_CN.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.": "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.",
3
  "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.",
4
  "时长必须为正数": "时长必须为正数",
5
  "请输入有效的浮点数": "请输入有效的浮点数",
@@ -17,7 +17,7 @@
17
  "惊喜": "惊喜",
18
  "平静": "平静",
19
  "情感描述文本": "情感描述文本",
20
- "请输入情感描述文本": "请输入情感描述文本",
21
  "高级生成参数设置": "高级生成参数设置",
22
  "情感向量之和不能超过1.5,请调整后重试。": "情感向量之和不能超过1.5,请调整后重试。",
23
  "音色参考音频": "音色参考音频",
@@ -36,5 +36,9 @@
36
  "Token数": "Token数",
37
  "情感控制方式": "情感控制方式",
38
  "GPT2 采样设置": "GPT2 采样设置",
39
- "参数会影响音频多样性和生成速度详见": "参数会影响音频多样性和生成速度详见"
 
 
 
 
40
  }
 
1
  {
2
+ "本软件以自拟协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.": "本软件以自拟协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.",
3
  "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.",
4
  "时长必须为正数": "时长必须为正数",
5
  "请输入有效的浮点数": "请输入有效的浮点数",
 
17
  "惊喜": "惊喜",
18
  "平静": "平静",
19
  "情感描述文本": "情感描述文本",
20
+ "请输入情绪描述(或留空以自动使用目标文本作为情绪描述)": "请输入情绪描述(或留空以自动使用目标文本作为情绪描述)",
21
  "高级生成参数设置": "高级生成参数设置",
22
  "情感向量之和不能超过1.5,请调整后重试。": "情感向量之和不能超过1.5,请调整后重试。",
23
  "音色参考音频": "音色参考音频",
 
36
  "Token数": "Token数",
37
  "情感控制方式": "情感控制方式",
38
  "GPT2 采样设置": "GPT2 采样设置",
39
+ "参数会影响音频多样性和生成速度详见": "参数会影响音频多样性和生成速度详见",
40
+ "是否进行采样": "是否进行采样",
41
+ "生成Token最大数量,过小导致音频被截断": "生成Token最大数量,过小导致音频被截断",
42
+ "显示实验功能": "显示实验功能",
43
+ "例如:委屈巴巴、危险在悄悄逼近": "例如:委屈巴巴、危险在悄悄逼近"
44
  }
webui.py CHANGED
@@ -1,6 +1,4 @@
1
  import json
2
- import logging
3
- import spaces
4
  import os
5
  import sys
6
  import threading
@@ -8,40 +6,60 @@ import time
8
 
9
  import warnings
10
 
11
- import pandas as pd
12
 
13
  warnings.filterwarnings("ignore", category=FutureWarning)
14
  warnings.filterwarnings("ignore", category=UserWarning)
15
 
 
 
16
  current_dir = os.path.dirname(os.path.abspath(__file__))
17
  sys.path.append(current_dir)
18
  sys.path.append(os.path.join(current_dir, "indextts"))
19
 
20
  import argparse
21
- parser = argparse.ArgumentParser(description="IndexTTS WebUI")
 
 
 
22
  parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode")
23
  parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
24
  parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the web UI on")
25
- parser.add_argument("--model_dir", type=str, default="checkpoints", help="Model checkpoints directory")
26
- parser.add_argument("--is_fp16", action="store_true", default=False, help="Fp16 infer")
 
 
 
27
  cmd_args = parser.parse_args()
28
 
29
- from tools.download_files import download_model_from_huggingface
30
- download_model_from_huggingface(os.path.join(current_dir,"checkpoints"),
31
- os.path.join(current_dir, "checkpoints","hf_cache"))
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  import gradio as gr
34
- from indextts import infer
35
  from indextts.infer_v2 import IndexTTS2
36
  from tools.i18n.i18n import I18nAuto
37
- from modelscope.hub import api
38
 
39
  i18n = I18nAuto(language="Auto")
40
  MODE = 'local'
41
  tts = IndexTTS2(model_dir=cmd_args.model_dir,
42
  cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),
43
- is_fp16=False,use_cuda_kernel=False)
44
-
 
 
45
  # 支持的语言列表
46
  LANGUAGES = {
47
  "中文": "zh_CN",
@@ -51,6 +69,9 @@ EMO_CHOICES = [i18n("与音色参考音频相同"),
51
  i18n("使用情感参考音频"),
52
  i18n("使用情感向量控制"),
53
  i18n("使用情感描述文本控制")]
 
 
 
54
  os.makedirs("outputs/tasks",exist_ok=True)
55
  os.makedirs("prompts",exist_ok=True)
56
 
@@ -79,15 +100,23 @@ with open("examples/cases.jsonl", "r", encoding="utf-8") as f:
79
  example.get("emo_vec_5",0),
80
  example.get("emo_vec_6",0),
81
  example.get("emo_vec_7",0),
82
- example.get("emo_vec_8",0)]
 
83
  )
84
 
85
- @spaces.GPU
 
 
 
 
 
 
 
86
  def gen_single(emo_control_method,prompt, text,
87
  emo_ref_path, emo_weight,
88
  vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8,
89
  emo_text,emo_random,
90
- max_text_tokens_per_sentence=120,
91
  *args, progress=gr.Progress()):
92
  output_path = None
93
  if not output_path:
@@ -110,28 +139,31 @@ def gen_single(emo_control_method,prompt, text,
110
  }
111
  if type(emo_control_method) is not int:
112
  emo_control_method = emo_control_method.value
113
- if emo_control_method == 0:
114
- emo_ref_path = None
115
- emo_weight = 1.0
116
- if emo_control_method == 1:
117
- emo_weight = emo_weight
118
- if emo_control_method == 2:
 
119
  vec = [vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8]
120
- vec_sum = sum([vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8])
121
- if vec_sum > 1.5:
122
- gr.Warning(i18n("情感向量之和不能超过1.5,请调整后重试。"))
123
- return
124
  else:
 
125
  vec = None
126
 
127
- print(f"Emo control mode:{emo_control_method},vec:{vec}")
 
 
 
 
128
  output = tts.infer(spk_audio_prompt=prompt, text=text,
129
  output_path=output_path,
130
  emo_audio_prompt=emo_ref_path, emo_alpha=emo_weight,
131
  emo_vector=vec,
132
  use_emo_text=(emo_control_method==3), emo_text=emo_text,use_random=emo_random,
133
  verbose=cmd_args.verbose,
134
- max_text_tokens_per_sentence=int(max_text_tokens_per_sentence),
135
  **kwargs)
136
  return gr.update(value=output,visible=True)
137
 
@@ -147,6 +179,7 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
147
  <a href='https://arxiv.org/abs/2506.21619'><img src='https://img.shields.io/badge/ArXiv-2506.21619-red'></a>
148
  </p>
149
  ''')
 
150
  with gr.Tab(i18n("音频生成")):
151
  with gr.Row():
152
  os.makedirs("prompts",exist_ok=True)
@@ -160,49 +193,54 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
160
  input_text_single = gr.TextArea(label=i18n("文本"),key="input_text_single", placeholder=i18n("请输入目标文本"), info=f"{i18n('当前模型版本')}{tts.model_version or '1.0'}")
161
  gen_button = gr.Button(i18n("生成语音"), key="gen_button",interactive=True)
162
  output_audio = gr.Audio(label=i18n("生成结果"), visible=True,key="output_audio")
 
163
  with gr.Accordion(i18n("功能设置")):
164
  # 情感控制选项部分
165
  with gr.Row():
166
  emo_control_method = gr.Radio(
167
- choices=EMO_CHOICES,
168
  type="index",
169
- value=EMO_CHOICES[0],label=i18n("情感控制方式"))
170
  # 情感参考音频部分
171
  with gr.Group(visible=False) as emotion_reference_group:
172
  with gr.Row():
173
  emo_upload = gr.Audio(label=i18n("上传情感参考音频"), type="filepath")
174
 
175
- with gr.Row():
176
- emo_weight = gr.Slider(label=i18n("情感权重"), minimum=0.0, maximum=1.6, value=0.8, step=0.01)
177
-
178
  # 情感随机采样
179
- with gr.Row():
180
- emo_random = gr.Checkbox(label=i18n("情感随机采样"),value=False,visible=False)
181
 
182
  # 情感向量控制部分
183
  with gr.Group(visible=False) as emotion_vector_group:
184
  with gr.Row():
185
  with gr.Column():
186
- vec1 = gr.Slider(label=i18n("喜"), minimum=0.0, maximum=1.4, value=0.0, step=0.05)
187
- vec2 = gr.Slider(label=i18n("怒"), minimum=0.0, maximum=1.4, value=0.0, step=0.05)
188
- vec3 = gr.Slider(label=i18n("哀"), minimum=0.0, maximum=1.4, value=0.0, step=0.05)
189
- vec4 = gr.Slider(label=i18n("惧"), minimum=0.0, maximum=1.4, value=0.0, step=0.05)
190
  with gr.Column():
191
- vec5 = gr.Slider(label=i18n("厌恶"), minimum=0.0, maximum=1.4, value=0.0, step=0.05)
192
- vec6 = gr.Slider(label=i18n("低落"), minimum=0.0, maximum=1.4, value=0.0, step=0.05)
193
- vec7 = gr.Slider(label=i18n("惊喜"), minimum=0.0, maximum=1.4, value=0.0, step=0.05)
194
- vec8 = gr.Slider(label=i18n("平静"), minimum=0.0, maximum=1.4, value=0.0, step=0.05)
195
 
196
  with gr.Group(visible=False) as emo_text_group:
197
  with gr.Row():
198
- emo_text = gr.Textbox(label=i18n("情感描述文本"), placeholder=i18n("请输入情感描述文本"), value="", info=i18n("例如:高兴,愤怒,悲伤等"))
 
 
 
 
 
 
 
199
 
200
- with gr.Accordion(i18n("高级生成参数设置"), open=False):
201
  with gr.Row():
202
  with gr.Column(scale=1):
203
- gr.Markdown(f"**{i18n('GPT2 采样设置')}** _{i18n('参数会影响音频多样性和生成速度详见')}[Generation strategies](https://huggingface.co/docs/transformers/main/en/generation_strategies)_")
204
  with gr.Row():
205
- do_sample = gr.Checkbox(label="do_sample", value=True, info="是否进行采样")
206
  temperature = gr.Slider(label="temperature", minimum=0.1, maximum=2.0, value=0.8, step=0.1)
207
  with gr.Row():
208
  top_p = gr.Slider(label="top_p", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
@@ -211,21 +249,22 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
211
  with gr.Row():
212
  repetition_penalty = gr.Number(label="repetition_penalty", precision=None, value=10.0, minimum=0.1, maximum=20.0, step=0.1)
213
  length_penalty = gr.Number(label="length_penalty", precision=None, value=0.0, minimum=-2.0, maximum=2.0, step=0.1)
214
- max_mel_tokens = gr.Slider(label="max_mel_tokens", value=1500, minimum=50, maximum=tts.cfg.gpt.max_mel_tokens, step=10, info="生成Token最大数量,过小导致音频被截断", key="max_mel_tokens")
215
  # with gr.Row():
216
  # typical_sampling = gr.Checkbox(label="typical_sampling", value=False, info="不建议使用")
217
  # typical_mass = gr.Slider(label="typical_mass", value=0.9, minimum=0.0, maximum=1.0, step=0.1)
218
  with gr.Column(scale=2):
219
  gr.Markdown(f'**{i18n("分句设置")}** _{i18n("参数会影响音频质量和生成速度")}_')
220
  with gr.Row():
221
- max_text_tokens_per_sentence = gr.Slider(
222
- label=i18n("分句最大Token数"), value=120, minimum=20, maximum=tts.cfg.gpt.max_text_tokens, step=2, key="max_text_tokens_per_sentence",
 
223
  info=i18n("建议80~200之间,值越大,分句越长;值越小,分句越碎;过小过大都可能导致音频质量不高"),
224
  )
225
- with gr.Accordion(i18n("预览分句结果"), open=True) as sentences_settings:
226
- sentences_preview = gr.Dataframe(
227
  headers=[i18n("序号"), i18n("分句内容"), i18n("Token数")],
228
- key="sentences_preview",
229
  wrap=True,
230
  )
231
  advanced_params = [
@@ -234,8 +273,20 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
234
  # typical_sampling, typical_mass,
235
  ]
236
 
237
- if len(example_cases) > 0:
238
- gr.Examples(
 
 
 
 
 
 
 
 
 
 
 
 
239
  examples=example_cases,
240
  examples_per_page=20,
241
  inputs=[prompt_audio,
@@ -244,71 +295,93 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
244
  emo_upload,
245
  emo_weight,
246
  emo_text,
247
- vec1,vec2,vec3,vec4,vec5,vec6,vec7,vec8]
248
  )
249
 
250
- def on_input_text_change(text, max_tokens_per_sentence):
251
  if text and len(text) > 0:
252
  text_tokens_list = tts.tokenizer.tokenize(text)
253
 
254
- sentences = tts.tokenizer.split_sentences(text_tokens_list, max_tokens_per_sentence=int(max_tokens_per_sentence))
255
  data = []
256
- for i, s in enumerate(sentences):
257
- sentence_str = ''.join(s)
258
  tokens_count = len(s)
259
- data.append([i, sentence_str, tokens_count])
260
  return {
261
- sentences_preview: gr.update(value=data, visible=True, type="array"),
262
  }
263
  else:
264
  df = pd.DataFrame([], columns=[i18n("序号"), i18n("分句内容"), i18n("Token数")])
265
  return {
266
- sentences_preview: gr.update(value=df),
267
  }
 
268
  def on_method_select(emo_control_method):
269
- if emo_control_method == 1:
270
  return (gr.update(visible=True),
271
  gr.update(visible=False),
272
  gr.update(visible=False),
273
- gr.update(visible=False)
 
274
  )
275
- elif emo_control_method == 2:
276
  return (gr.update(visible=False),
277
  gr.update(visible=True),
278
  gr.update(visible=True),
 
279
  gr.update(visible=False)
280
  )
281
- elif emo_control_method == 3:
282
  return (gr.update(visible=False),
283
  gr.update(visible=True),
284
  gr.update(visible=False),
 
285
  gr.update(visible=True)
286
  )
287
- else:
288
  return (gr.update(visible=False),
 
289
  gr.update(visible=False),
290
  gr.update(visible=False),
291
  gr.update(visible=False)
292
  )
293
 
 
 
 
 
 
 
 
 
294
  emo_control_method.select(on_method_select,
295
  inputs=[emo_control_method],
296
  outputs=[emotion_reference_group,
297
- emo_random,
298
  emotion_vector_group,
299
- emo_text_group]
 
300
  )
301
 
302
  input_text_single.change(
303
  on_input_text_change,
304
- inputs=[input_text_single, max_text_tokens_per_sentence],
305
- outputs=[sentences_preview]
306
  )
307
- max_text_tokens_per_sentence.change(
 
 
 
 
 
 
 
308
  on_input_text_change,
309
- inputs=[input_text_single, max_text_tokens_per_sentence],
310
- outputs=[sentences_preview]
311
  )
 
312
  prompt_audio.upload(update_prompt_audio,
313
  inputs=[],
314
  outputs=[gen_button])
@@ -317,7 +390,7 @@ with gr.Blocks(title="IndexTTS Demo") as demo:
317
  inputs=[emo_control_method,prompt_audio, input_text_single, emo_upload, emo_weight,
318
  vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8,
319
  emo_text,emo_random,
320
- max_text_tokens_per_sentence,
321
  *advanced_params,
322
  ],
323
  outputs=[output_audio])
 
1
  import json
 
 
2
  import os
3
  import sys
4
  import threading
 
6
 
7
  import warnings
8
 
9
+ import numpy as np
10
 
11
  warnings.filterwarnings("ignore", category=FutureWarning)
12
  warnings.filterwarnings("ignore", category=UserWarning)
13
 
14
+ import pandas as pd
15
+
16
  current_dir = os.path.dirname(os.path.abspath(__file__))
17
  sys.path.append(current_dir)
18
  sys.path.append(os.path.join(current_dir, "indextts"))
19
 
20
  import argparse
21
+ parser = argparse.ArgumentParser(
22
+ description="IndexTTS WebUI",
23
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
24
+ )
25
  parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode")
26
  parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
27
  parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the web UI on")
28
+ parser.add_argument("--model_dir", type=str, default="./checkpoints", help="Model checkpoints directory")
29
+ parser.add_argument("--fp16", action="store_true", default=False, help="Use FP16 for inference if available")
30
+ parser.add_argument("--deepspeed", action="store_true", default=False, help="Use DeepSpeed to accelerate if available")
31
+ parser.add_argument("--cuda_kernel", action="store_true", default=False, help="Use CUDA kernel for inference if available")
32
+ parser.add_argument("--gui_seg_tokens", type=int, default=120, help="GUI: Max tokens per generation segment")
33
  cmd_args = parser.parse_args()
34
 
35
+ if not os.path.exists(cmd_args.model_dir):
36
+ print(f"Model directory {cmd_args.model_dir} does not exist. Please download the model first.")
37
+ sys.exit(1)
38
+
39
+ for file in [
40
+ "bpe.model",
41
+ "gpt.pth",
42
+ "config.yaml",
43
+ "s2mel.pth",
44
+ "wav2vec2bert_stats.pt"
45
+ ]:
46
+ file_path = os.path.join(cmd_args.model_dir, file)
47
+ if not os.path.exists(file_path):
48
+ print(f"Required file {file_path} does not exist. Please download it.")
49
+ sys.exit(1)
50
 
51
  import gradio as gr
 
52
  from indextts.infer_v2 import IndexTTS2
53
  from tools.i18n.i18n import I18nAuto
 
54
 
55
  i18n = I18nAuto(language="Auto")
56
  MODE = 'local'
57
  tts = IndexTTS2(model_dir=cmd_args.model_dir,
58
  cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),
59
+ use_fp16=cmd_args.fp16,
60
+ use_deepspeed=cmd_args.deepspeed,
61
+ use_cuda_kernel=cmd_args.cuda_kernel,
62
+ )
63
  # 支持的语言列表
64
  LANGUAGES = {
65
  "中文": "zh_CN",
 
69
  i18n("使用情感参考音频"),
70
  i18n("使用情感向量控制"),
71
  i18n("使用情感描述文本控制")]
72
+ EMO_CHOICES_BASE = EMO_CHOICES[:3] # 基础选项
73
+ EMO_CHOICES_EXPERIMENTAL = EMO_CHOICES # 全部选项(包括文本描述)
74
+
75
  os.makedirs("outputs/tasks",exist_ok=True)
76
  os.makedirs("prompts",exist_ok=True)
77
 
 
100
  example.get("emo_vec_5",0),
101
  example.get("emo_vec_6",0),
102
  example.get("emo_vec_7",0),
103
+ example.get("emo_vec_8",0),
104
+ example.get("emo_text") is not None]
105
  )
106
 
107
+ def normalize_emo_vec(emo_vec):
108
+ # emotion factors for better user experience
109
+ k_vec = [0.75,0.70,0.80,0.80,0.75,0.75,0.55,0.45]
110
+ tmp = np.array(k_vec) * np.array(emo_vec)
111
+ if np.sum(tmp) > 0.8:
112
+ tmp = tmp * 0.8/ np.sum(tmp)
113
+ return tmp.tolist()
114
+
115
  def gen_single(emo_control_method,prompt, text,
116
  emo_ref_path, emo_weight,
117
  vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8,
118
  emo_text,emo_random,
119
+ max_text_tokens_per_segment=120,
120
  *args, progress=gr.Progress()):
121
  output_path = None
122
  if not output_path:
 
139
  }
140
  if type(emo_control_method) is not int:
141
  emo_control_method = emo_control_method.value
142
+ if emo_control_method == 0: # emotion from speaker
143
+ emo_ref_path = None # remove external reference audio
144
+ if emo_control_method == 1: # emotion from reference audio
145
+ # normalize emo_alpha for better user experience
146
+ emo_weight = emo_weight * 0.8
147
+ pass
148
+ if emo_control_method == 2: # emotion from custom vectors
149
  vec = [vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8]
150
+ vec = normalize_emo_vec(vec)
 
 
 
151
  else:
152
+ # don't use the emotion vector inputs for the other modes
153
  vec = None
154
 
155
+ if emo_text == "":
156
+ # erase empty emotion descriptions; `infer()` will then automatically use the main prompt
157
+ emo_text = None
158
+
159
+ print(f"Emo control mode:{emo_control_method},weight:{emo_weight},vec:{vec}")
160
  output = tts.infer(spk_audio_prompt=prompt, text=text,
161
  output_path=output_path,
162
  emo_audio_prompt=emo_ref_path, emo_alpha=emo_weight,
163
  emo_vector=vec,
164
  use_emo_text=(emo_control_method==3), emo_text=emo_text,use_random=emo_random,
165
  verbose=cmd_args.verbose,
166
+ max_text_tokens_per_segment=int(max_text_tokens_per_segment),
167
  **kwargs)
168
  return gr.update(value=output,visible=True)
169
 
 
179
  <a href='https://arxiv.org/abs/2506.21619'><img src='https://img.shields.io/badge/ArXiv-2506.21619-red'></a>
180
  </p>
181
  ''')
182
+
183
  with gr.Tab(i18n("音频生成")):
184
  with gr.Row():
185
  os.makedirs("prompts",exist_ok=True)
 
193
  input_text_single = gr.TextArea(label=i18n("文本"),key="input_text_single", placeholder=i18n("请输入目标文本"), info=f"{i18n('当前模型版本')}{tts.model_version or '1.0'}")
194
  gen_button = gr.Button(i18n("生成语音"), key="gen_button",interactive=True)
195
  output_audio = gr.Audio(label=i18n("生成结果"), visible=True,key="output_audio")
196
+ experimental_checkbox = gr.Checkbox(label=i18n("显示实验功能"),value=False)
197
  with gr.Accordion(i18n("功能设置")):
198
  # 情感控制选项部分
199
  with gr.Row():
200
  emo_control_method = gr.Radio(
201
+ choices=EMO_CHOICES_BASE,
202
  type="index",
203
+ value=EMO_CHOICES_BASE[0],label=i18n("情感控制方式"))
204
  # 情感参考音频部分
205
  with gr.Group(visible=False) as emotion_reference_group:
206
  with gr.Row():
207
  emo_upload = gr.Audio(label=i18n("上传情感参考音频"), type="filepath")
208
 
 
 
 
209
  # 情感随机采样
210
+ with gr.Row(visible=False) as emotion_randomize_group:
211
+ emo_random = gr.Checkbox(label=i18n("情感随机采样"), value=False)
212
 
213
  # 情感向量控制部分
214
  with gr.Group(visible=False) as emotion_vector_group:
215
  with gr.Row():
216
  with gr.Column():
217
+ vec1 = gr.Slider(label=i18n("喜"), minimum=0.0, maximum=1.0, value=0.0, step=0.05)
218
+ vec2 = gr.Slider(label=i18n("怒"), minimum=0.0, maximum=1.0, value=0.0, step=0.05)
219
+ vec3 = gr.Slider(label=i18n("哀"), minimum=0.0, maximum=1.0, value=0.0, step=0.05)
220
+ vec4 = gr.Slider(label=i18n("惧"), minimum=0.0, maximum=1.0, value=0.0, step=0.05)
221
  with gr.Column():
222
+ vec5 = gr.Slider(label=i18n("厌恶"), minimum=0.0, maximum=1.0, value=0.0, step=0.05)
223
+ vec6 = gr.Slider(label=i18n("低落"), minimum=0.0, maximum=1.0, value=0.0, step=0.05)
224
+ vec7 = gr.Slider(label=i18n("惊喜"), minimum=0.0, maximum=1.0, value=0.0, step=0.05)
225
+ vec8 = gr.Slider(label=i18n("平静"), minimum=0.0, maximum=1.0, value=0.0, step=0.05)
226
 
227
  with gr.Group(visible=False) as emo_text_group:
228
  with gr.Row():
229
+ emo_text = gr.Textbox(label=i18n("情感描述文本"),
230
+ placeholder=i18n("请输入情绪描述(或留空以自动使用目标文本作为情绪描述)"),
231
+ value="",
232
+ info=i18n("例如:委屈巴巴、危险在悄悄逼近"))
233
+
234
+
235
+ with gr.Row(visible=False) as emo_weight_group:
236
+ emo_weight = gr.Slider(label=i18n("情感权重"), minimum=0.0, maximum=1.0, value=0.8, step=0.01)
237
 
238
+ with gr.Accordion(i18n("高级生成参数设置"), open=False,visible=False) as advanced_settings_group:
239
  with gr.Row():
240
  with gr.Column(scale=1):
241
+ gr.Markdown(f"**{i18n('GPT2 采样设置')}** _{i18n('参数会影响音频多样性和生成速度详见')} [Generation strategies](https://huggingface.co/docs/transformers/main/en/generation_strategies)._")
242
  with gr.Row():
243
+ do_sample = gr.Checkbox(label="do_sample", value=True, info=i18n("是否进行采样"))
244
  temperature = gr.Slider(label="temperature", minimum=0.1, maximum=2.0, value=0.8, step=0.1)
245
  with gr.Row():
246
  top_p = gr.Slider(label="top_p", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
 
249
  with gr.Row():
250
  repetition_penalty = gr.Number(label="repetition_penalty", precision=None, value=10.0, minimum=0.1, maximum=20.0, step=0.1)
251
  length_penalty = gr.Number(label="length_penalty", precision=None, value=0.0, minimum=-2.0, maximum=2.0, step=0.1)
252
+ max_mel_tokens = gr.Slider(label="max_mel_tokens", value=1500, minimum=50, maximum=tts.cfg.gpt.max_mel_tokens, step=10, info=i18n("生成Token最大数量,过小导致音频被截断"), key="max_mel_tokens")
253
  # with gr.Row():
254
  # typical_sampling = gr.Checkbox(label="typical_sampling", value=False, info="不建议使用")
255
  # typical_mass = gr.Slider(label="typical_mass", value=0.9, minimum=0.0, maximum=1.0, step=0.1)
256
  with gr.Column(scale=2):
257
  gr.Markdown(f'**{i18n("分句设置")}** _{i18n("参数会影响音频质量和生成速度")}_')
258
  with gr.Row():
259
+ initial_value = max(20, min(tts.cfg.gpt.max_text_tokens, cmd_args.gui_seg_tokens))
260
+ max_text_tokens_per_segment = gr.Slider(
261
+ label=i18n("分句最大Token数"), value=initial_value, minimum=20, maximum=tts.cfg.gpt.max_text_tokens, step=2, key="max_text_tokens_per_segment",
262
  info=i18n("建议80~200之间,值越大,分句越长;值越小,分句越碎;过小过大都可能导致音频质量不高"),
263
  )
264
+ with gr.Accordion(i18n("预览分句结果"), open=True) as segments_settings:
265
+ segments_preview = gr.Dataframe(
266
  headers=[i18n("序号"), i18n("分句内容"), i18n("Token数")],
267
+ key="segments_preview",
268
  wrap=True,
269
  )
270
  advanced_params = [
 
273
  # typical_sampling, typical_mass,
274
  ]
275
 
276
+ if len(example_cases) > 2:
277
+ example_table = gr.Examples(
278
+ examples=example_cases[:-2],
279
+ examples_per_page=20,
280
+ inputs=[prompt_audio,
281
+ emo_control_method,
282
+ input_text_single,
283
+ emo_upload,
284
+ emo_weight,
285
+ emo_text,
286
+ vec1,vec2,vec3,vec4,vec5,vec6,vec7,vec8,experimental_checkbox]
287
+ )
288
+ elif len(example_cases) > 0:
289
+ example_table = gr.Examples(
290
  examples=example_cases,
291
  examples_per_page=20,
292
  inputs=[prompt_audio,
 
295
  emo_upload,
296
  emo_weight,
297
  emo_text,
298
+ vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8, experimental_checkbox]
299
  )
300
 
301
+ def on_input_text_change(text, max_text_tokens_per_segment):
302
  if text and len(text) > 0:
303
  text_tokens_list = tts.tokenizer.tokenize(text)
304
 
305
+ segments = tts.tokenizer.split_segments(text_tokens_list, max_text_tokens_per_segment=int(max_text_tokens_per_segment))
306
  data = []
307
+ for i, s in enumerate(segments):
308
+ segment_str = ''.join(s)
309
  tokens_count = len(s)
310
+ data.append([i, segment_str, tokens_count])
311
  return {
312
+ segments_preview: gr.update(value=data, visible=True, type="array"),
313
  }
314
  else:
315
  df = pd.DataFrame([], columns=[i18n("序号"), i18n("分句内容"), i18n("Token数")])
316
  return {
317
+ segments_preview: gr.update(value=df),
318
  }
319
+
320
  def on_method_select(emo_control_method):
321
+ if emo_control_method == 1: # emotion reference audio
322
  return (gr.update(visible=True),
323
  gr.update(visible=False),
324
  gr.update(visible=False),
325
+ gr.update(visible=False),
326
+ gr.update(visible=True)
327
  )
328
+ elif emo_control_method == 2: # emotion vectors
329
  return (gr.update(visible=False),
330
  gr.update(visible=True),
331
  gr.update(visible=True),
332
+ gr.update(visible=False),
333
  gr.update(visible=False)
334
  )
335
+ elif emo_control_method == 3: # emotion text description
336
  return (gr.update(visible=False),
337
  gr.update(visible=True),
338
  gr.update(visible=False),
339
+ gr.update(visible=True),
340
  gr.update(visible=True)
341
  )
342
+ else: # 0: same as speaker voice
343
  return (gr.update(visible=False),
344
+ gr.update(visible=False),
345
  gr.update(visible=False),
346
  gr.update(visible=False),
347
  gr.update(visible=False)
348
  )
349
 
350
+ def on_experimental_change(is_exp):
351
+ # 切换情感控制选项
352
+ # 第三个返回值实际没有起作用
353
+ if is_exp:
354
+ return gr.update(choices=EMO_CHOICES_EXPERIMENTAL, value=EMO_CHOICES_EXPERIMENTAL[0]), gr.update(visible=True),gr.update(value=example_cases)
355
+ else:
356
+ return gr.update(choices=EMO_CHOICES_BASE, value=EMO_CHOICES_BASE[0]), gr.update(visible=False),gr.update(value=example_cases[:-2])
357
+
358
  emo_control_method.select(on_method_select,
359
  inputs=[emo_control_method],
360
  outputs=[emotion_reference_group,
361
+ emotion_randomize_group,
362
  emotion_vector_group,
363
+ emo_text_group,
364
+ emo_weight_group]
365
  )
366
 
367
  input_text_single.change(
368
  on_input_text_change,
369
+ inputs=[input_text_single, max_text_tokens_per_segment],
370
+ outputs=[segments_preview]
371
  )
372
+
373
+ experimental_checkbox.change(
374
+ on_experimental_change,
375
+ inputs=[experimental_checkbox],
376
+ outputs=[emo_control_method, advanced_settings_group,example_table.dataset] # 高级参数Accordion
377
+ )
378
+
379
+ max_text_tokens_per_segment.change(
380
  on_input_text_change,
381
+ inputs=[input_text_single, max_text_tokens_per_segment],
382
+ outputs=[segments_preview]
383
  )
384
+
385
  prompt_audio.upload(update_prompt_audio,
386
  inputs=[],
387
  outputs=[gen_button])
 
390
  inputs=[emo_control_method,prompt_audio, input_text_single, emo_upload, emo_weight,
391
  vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8,
392
  emo_text,emo_random,
393
+ max_text_tokens_per_segment,
394
  *advanced_params,
395
  ],
396
  outputs=[output_audio])