# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
#               2025                (authors: Yuekai Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
""" Example Usage
torchrun --nproc_per_node=1 \
    benchmark.py --output-dir $log_dir \
    --batch-size $batch_size \
    --enable-warmup \
    --split-name $split_name \
    --model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
    --vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
    --vocoder-trt-engine-path $vocoder_trt_engine_path \
    --backend-type $backend_type \
    --tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
"""
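# Overview (summarized from the code below): the script synthesizes the selected seed_tts split
# with either the TensorRT-LLM ("trt") or the native PyTorch ("pytorch") F5-TTS backend, writes
# one wav per utterance to --output-dir, and saves an RTF/timing summary to <output-dir>/rtf.txt.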
import argparse
import json
import os
import time
from typing import Dict, List, Union

import datasets
import jieba
import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from datasets import load_dataset
from f5_tts_trtllm import F5TTS
from huggingface_hub import hf_hub_download
from pypinyin import Style, lazy_pinyin
from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session, TensorInfo
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from vocos import Vocos

torch.manual_seed(0)
def get_args():
    parser = argparse.ArgumentParser(description="benchmark F5-TTS inference (TensorRT-LLM or PyTorch backend)")
    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="huggingface dataset split name",
    )
    parser.add_argument("--output-dir", required=True, type=str, help="dir to save result")
    parser.add_argument(
        "--vocab-file",
        required=True,
        type=str,
        help="vocab file",
    )
    parser.add_argument(
        "--model-path",
        required=True,
        type=str,
        help="model path, to load text embedding",
    )
    parser.add_argument(
        "--tllm-model-dir",
        required=True,
        type=str,
        help="tllm model dir",
    )
    parser.add_argument(
        "--batch-size",
        required=True,
        type=int,
        help="batch size (per-device) for inference",
    )
    parser.add_argument("--num-workers", type=int, default=0, help="workers for dataloader")
    parser.add_argument("--prefetch", type=int, default=None, help="prefetch for dataloader")
    parser.add_argument(
        "--vocoder",
        default="vocos",
        type=str,
        help="vocoder name",
    )
    parser.add_argument(
        "--vocoder-trt-engine-path",
        default=None,
        type=str,
        help="vocoder trt engine path",
    )
    parser.add_argument("--enable-warmup", action="store_true")
    parser.add_argument("--remove-input-padding", action="store_true")
    parser.add_argument("--use-perf", action="store_true", help="use nvtx to record performance")
    # default must be one of the declared choices, otherwise neither backend branch in main() is taken
    parser.add_argument("--backend-type", type=str, default="trt", choices=["trt", "pytorch"], help="backend type")
    args = parser.parse_args()
    return args
def padded_mel_batch(ref_mels, max_seq_len):
    padded_ref_mels = []
    for mel in ref_mels:
        # pad along the time dimension (dim 0); the mel-bin dimension is left untouched
        padded_ref_mel = F.pad(mel, (0, 0, 0, max_seq_len - mel.shape[0]), value=0)
        padded_ref_mels.append(padded_ref_mel)
    padded_ref_mels = torch.stack(padded_ref_mels)
    return padded_ref_mels
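# Collation conventions used by data_collator below (inferred from the code, summarized here):
#   * reference mels and text sequences are padded to max_seq_len, the largest *estimated*
#     prompt+target mel length in the batch;
#   * each estimate is ref_mel_len * (1 + target_bytes / prompt_bytes), i.e. the prompt mel length
#     scaled by the byte-length ratio of target to prompt text;
#   * text token ids are shifted by +1 so that 0 can act as the padding id inside the TRT-LLM
#     F5-TTS engine (see the WAR comment), while batch-level padding positions are filled with -1.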
def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
    if use_perf:
        torch.cuda.nvtx.range_push("data_collator")
    target_sample_rate = 24000
    target_rms = 0.1
    ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = (
        [],
        [],
        [],
        [],
        [],
    )
    for i, item in enumerate(batch):
        item_id, prompt_text, target_text = (
            item["id"],
            item["prompt_text"],
            item["target_text"],
        )
        ids.append(item_id)
        reference_target_texts_list.append(prompt_text + target_text)

        ref_audio_org, ref_sr = (
            item["prompt_audio"]["array"],
            item["prompt_audio"]["sampling_rate"],
        )
        ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
        if ref_rms < target_rms:
            ref_audio_org = ref_audio_org * target_rms / ref_rms

        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio_org)
        else:
            ref_audio = ref_audio_org

        if use_perf:
            torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
        ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device="cuda")
        if use_perf:
            torch.cuda.nvtx.range_pop()
        ref_mel = ref_mel.squeeze()
        ref_mel_len = ref_mel.shape[0]
        assert ref_mel.shape[1] == 100

        ref_mel_list.append(ref_mel)
        ref_mel_len_list.append(ref_mel_len)
        estimated_reference_target_mel_len.append(
            int(ref_mel.shape[0] * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
        )

    max_seq_len = max(estimated_reference_target_mel_len)
    ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len)
    ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)

    pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
    text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)

    for i, item in enumerate(text_pad_sequence):
        text_pad_sequence[i] = F.pad(
            item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
        )
        text_pad_sequence[i] += 1  # WAR: 0 is reserved for padding token, hard coding in F5-TTS
    text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
    text_pad_sequence = F.pad(
        text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
    )
    if use_perf:
        torch.cuda.nvtx.range_pop()
    return {
        "ids": ids,
        "ref_mel_batch": ref_mel_batch,
        "ref_mel_len_batch": ref_mel_len_batch,
        "text_pad_sequence": text_pad_sequence,
        "estimated_reference_target_mel_len": estimated_reference_target_mel_len,
    }
def init_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    print(
        "Inference on multiple gpus, this gpu {}".format(local_rank)
        + ", rank {}, world_size {}".format(rank, world_size)
    )
    torch.cuda.set_device(local_rank)
    # Initialize the default NCCL process group; the device was already selected via set_device above
    dist.init_process_group("nccl")
    return world_size, local_rank, rank
def get_tokenizer(vocab_file_path: str):
    """
    tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
                - "char" for char-wise tokenizer, need .txt vocab_file
                - "byte" for utf-8 tokenizer
                - "custom" if you're directly passing in a path to the vocab.txt you want to use
    vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
                - if use "char", derived from unfiltered character & symbol counts of custom dataset
                - if use "byte", set to 256 (unicode byte range)
    """
    with open(vocab_file_path, "r", encoding="utf-8") as f:
        vocab_char_map = {}
        for i, char in enumerate(f):
            vocab_char_map[char[:-1]] = i
    vocab_size = len(vocab_char_map)
    return vocab_char_map, vocab_size
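# Note: get_tokenizer assumes a plain-text vocab file with one token per line; a token's id is its
# 0-based line number, and char[:-1] simply strips the trailing newline (so a line holding a single
# space is a valid token). Tokens missing from the map later fall back to id 0 in list_str_to_idx.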
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
    final_reference_target_texts_list = []
    custom_trans = str.maketrans(
        {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
    )  # add custom trans here, to address oov

    def is_chinese(c):
        return "\u3100" <= c <= "\u9fff"  # common chinese characters

    for text in reference_target_texts_list:
        char_list = []
        text = text.translate(custom_trans)
        for seg in jieba.cut(text):
            seg_byte_len = len(bytes(seg, "UTF-8"))
            if seg_byte_len == len(seg):  # if pure alphabets and symbols
                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                    char_list.append(" ")
                char_list.extend(seg)
            elif polyphone and seg_byte_len == 3 * len(seg):  # if pure east asian characters
                seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
                for i, c in enumerate(seg):
                    if is_chinese(c):
                        char_list.append(" ")
                    char_list.append(seg_[i])
            else:  # if mixed characters, alphabets and symbols
                for c in seg:
                    if ord(c) < 256:
                        char_list.extend(c)
                    elif is_chinese(c):
                        char_list.append(" ")
                        char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
                    else:
                        char_list.append(c)
        final_reference_target_texts_list.append(char_list)
    return final_reference_target_texts_list
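# Note on the heuristics above: for a UTF-8 string, len(bytes(seg)) == len(seg) only when every
# character is single-byte ASCII, and len(bytes(seg)) == 3 * len(seg) when every character is a
# 3-byte code point, which covers the common CJK range. Mixed segments fall through to the
# per-character branch. Chinese characters become tone3-style pinyin (e.g. "ni3"), each preceded
# by a space so every syllable is looked up as its own vocab token.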
def list_str_to_idx(
    text: Union[List[str], List[List[str]]],
    vocab_char_map: Dict[str, int],  # {char: idx}
    padding_value=-1,
):
    list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text]  # pinyin or char style
    # text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
    return list_idx_tensors
def load_vocoder(
    vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
):
    if vocoder_name == "vocos":
        if vocoder_trt_engine_path is not None:
            vocoder = VocosTensorRT(engine_path=vocoder_trt_engine_path)
        else:
            # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
            if is_local:
                print(f"Load vocos from local path {local_path}")
                config_path = f"{local_path}/config.yaml"
                model_path = f"{local_path}/pytorch_model.bin"
            else:
                print("Download Vocos from huggingface charactr/vocos-mel-24khz")
                repo_id = "charactr/vocos-mel-24khz"
                config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
                model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
            vocoder = Vocos.from_hparams(config_path)
            state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
            from vocos.feature_extractors import EncodecFeatures

            if isinstance(vocoder.feature_extractor, EncodecFeatures):
                encodec_parameters = {
                    "feature_extractor.encodec." + key: value
                    for key, value in vocoder.feature_extractor.encodec.state_dict().items()
                }
                state_dict.update(encodec_parameters)
            vocoder.load_state_dict(state_dict)
            vocoder = vocoder.eval().to(device)
    elif vocoder_name == "bigvgan":
        raise NotImplementedError("BigVGAN is not implemented yet")
    return vocoder
def mel_spectrogram(waveform, vocoder="vocos", device="cuda"):
    if vocoder == "vocos":
        mel_stft = torchaudio.transforms.MelSpectrogram(
            sample_rate=24000,
            n_fft=1024,
            win_length=1024,
            hop_length=256,
            n_mels=100,
            power=1,
            center=True,
            normalized=False,
            norm=None,
        ).to(device)
    mel = mel_stft(waveform.to(device))
    mel = mel.clamp(min=1e-5).log()
    return mel.transpose(1, 2)
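# The settings above (24 kHz, n_fft = win_length = 1024, hop_length = 256, 100 mel bins, amplitude
# spectrogram clamped at 1e-5 before the log) match the features the charactr/vocos-mel-24khz
# vocoder expects, so one mel frame covers 256 samples (~10.7 ms) and the output is (B, T, 100).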
class VocosTensorRT:
    def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
        logger.info(f"Loading vocoder engine from {engine_path}")
        self.engine_path = engine_path
        with open(engine_path, "rb") as f:
            engine_buffer = f.read()
        self.session = Session.from_serialized_engine(engine_buffer)
        self.stream = stream if stream is not None else torch.cuda.current_stream().cuda_stream

    def decode(self, mels):
        mels = mels.contiguous()
        inputs = {"mel": mels}
        output_info = self.session.infer_shapes([TensorInfo("mel", trt.DataType.FLOAT, mels.shape)])
        outputs = {
            t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda") for t in output_info
        }
        ok = self.session.run(inputs, outputs, self.stream)
        assert ok, "Runtime execution failed for vocoder session"
        samples = outputs["waveform"]
        return samples
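# Minimal usage sketch for VocosTensorRT (assumption: the engine was exported with an input tensor
# named "mel" and an output tensor named "waveform", which is what decode() relies on):
#
#   vocoder = VocosTensorRT(engine_path="./vocos_vocoder.plan")
#   mel = torch.randn(1, 100, 256, device="cuda", dtype=torch.float32)  # (batch, n_mels, frames)
#   wav = vocoder.decode(mel)  # CUDA float tensor, roughly (batch, frames * 256) samples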
def main():
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)

    assert torch.cuda.is_available()
    world_size, local_rank, rank = init_distributed()
    device = torch.device(f"cuda:{local_rank}")

    vocab_char_map, vocab_size = get_tokenizer(args.vocab_file)

    tllm_model_dir = args.tllm_model_dir
    config_file = os.path.join(tllm_model_dir, "config.json")
    with open(config_file) as f:
        config = json.load(f)
    if args.backend_type == "trt":
        model = F5TTS(
            config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size
        )
    elif args.backend_type == "pytorch":
        import sys

        sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
        from f5_tts.infer.utils_infer import load_model
        from f5_tts.model import DiT

        F5TTS_model_cfg = dict(
            dim=1024,
            depth=22,
            heads=16,
            ff_mult=2,
            text_dim=512,
            conv_layers=4,
            pe_attn_head=1,
            text_mask_padding=False,
        )
        model = load_model(DiT, F5TTS_model_cfg, args.model_path)

    vocoder = load_vocoder(
        vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
    )

    dataset = load_dataset(
        "yuekai/seed_tts",
        split=args.split_name,
        trust_remote_code=True,
    )
    def add_estimated_duration(example):
        prompt_audio_len = example["prompt_audio"]["array"].shape[0]
        scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"])
        estimated_duration = prompt_audio_len * scale_factor
        example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"]
        return example

    dataset = dataset.map(add_estimated_duration)
    dataset = dataset.sort("estimated_duration", reverse=True)
    if args.use_perf:
        # dataset_list = [dataset.select(range(1)) for i in range(16)]  # seq_len 1000
        dataset_list_short = [dataset.select([24]) for i in range(8)]  # seq_len 719
        # dataset_list_long = [dataset.select([23]) for i in range(8)]  # seq_len 2002
        # dataset = datasets.concatenate_datasets(dataset_list_short + dataset_list_long)
        dataset = datasets.concatenate_datasets(dataset_list_short)

    if world_size > 1:
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    else:
        # sampler=None keeps the sorted dataset order; shuffling is disabled below as well
        sampler = None

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=lambda x: data_collator(x, vocab_char_map, use_perf=args.use_perf),
    )

    total_steps = len(dataset)
    if args.enable_warmup:
        for batch in dataloader:
            ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
            text_pad_seq = batch["text_pad_sequence"].to(device)
            total_mel_lens = batch["estimated_reference_target_mel_len"]
            if args.backend_type == "trt":
                _ = model.sample(
                    text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding
                )
            elif args.backend_type == "pytorch":
                with torch.inference_mode():
                    # undo the +1 index shift applied in the collator so padding positions are -1 again,
                    # which is what the PyTorch F5-TTS model expects
                    text_pad_seq -= 1
                    text_pad_seq[text_pad_seq == -2] = -1
                    total_mel_lens = torch.tensor(total_mel_lens, device=device)
                    generated, _ = model.sample(
                        cond=ref_mels,
                        text=text_pad_seq,
                        duration=total_mel_lens,
                        steps=16,
                        cfg_strength=2.0,
                        sway_sampling_coef=-1,
                    )
    if rank == 0:
        progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")

    decoding_time = 0
    vocoder_time = 0
    total_duration = 0
    if args.use_perf:
        torch.cuda.cudart().cudaProfilerStart()
    total_decoding_time = time.time()
    for batch in dataloader:
        if args.use_perf:
            torch.cuda.nvtx.range_push("data sample")
        ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
        text_pad_seq = batch["text_pad_sequence"].to(device)
        total_mel_lens = batch["estimated_reference_target_mel_len"]
        if args.use_perf:
            torch.cuda.nvtx.range_pop()
        if args.backend_type == "trt":
            generated, cost_time = model.sample(
                text_pad_seq,
                ref_mels,
                ref_mel_lens,
                total_mel_lens,
                remove_input_padding=args.remove_input_padding,
                use_perf=args.use_perf,
            )
        elif args.backend_type == "pytorch":
            total_mel_lens = torch.tensor(total_mel_lens, device=device)
            with torch.inference_mode():
                start_time = time.time()
                text_pad_seq -= 1
                text_pad_seq[text_pad_seq == -2] = -1
                generated, _ = model.sample(
                    cond=ref_mels,
                    text=text_pad_seq,
                    duration=total_mel_lens,
                    lens=ref_mel_lens,
                    steps=16,
                    cfg_strength=2.0,
                    sway_sampling_coef=-1,
                )
                cost_time = time.time() - start_time
        decoding_time += cost_time

        vocoder_start_time = time.time()
        for i, gen in enumerate(generated):
            gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
            gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
            if args.vocoder == "vocos":
                if args.use_perf:
                    torch.cuda.nvtx.range_push("vocoder decode")
                generated_wave = vocoder.decode(gen_mel_spec).cpu()
                if args.use_perf:
                    torch.cuda.nvtx.range_pop()
            else:
                generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
            target_rms = 0.1
            target_sample_rate = 24_000
            # if ref_rms_list[i] < target_rms:
            #     generated_wave = generated_wave * ref_rms_list[i] / target_rms
            rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
            if rms < target_rms:
                generated_wave = generated_wave * target_rms / rms
            utt = batch["ids"][i]
            torchaudio.save(
                f"{args.output_dir}/{utt}.wav",
                generated_wave,
                target_sample_rate,
            )
            total_duration += generated_wave.shape[1] / target_sample_rate
        vocoder_time += time.time() - vocoder_start_time
        if rank == 0:
            progress_bar.update(world_size * len(batch["ids"]))
    total_decoding_time = time.time() - total_decoding_time

    if rank == 0:
        progress_bar.close()
        rtf = total_decoding_time / total_duration
        s = f"RTF: {rtf:.4f}\n"
        s += f"total_duration: {total_duration:.3f} seconds\n"
        s += f"({total_duration / 3600:.2f} hours)\n"
        s += f"DiT time: {decoding_time:.3f} seconds ({decoding_time / 3600:.2f} hours)\n"
        s += f"Vocoder time: {vocoder_time:.3f} seconds ({vocoder_time / 3600:.2f} hours)\n"
        s += f"total decoding time: {total_decoding_time:.3f} seconds ({total_decoding_time / 3600:.2f} hours)\n"
        s += f"batch size: {args.batch_size}\n"
        print(s)
        with open(f"{args.output_dir}/rtf.txt", "w") as f:
            f.write(s)

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()