import json
import os
import time
from contextlib import contextmanager
from typing import Optional
from unicodedata import normalize

import numpy as np
import onnxruntime as ort
import soundfile as sf
from huggingface_hub import snapshot_download


class UnicodeProcessor:
    def __init__(self, unicode_indexer_path: str):
        # indexer maps unicode code points to model token ids
        with open(unicode_indexer_path, "r") as f:
            self.indexer = json.load(f)

    def _preprocess_text(self, text: str) -> str:
        # TODO: add more preprocessing
        text = normalize("NFKD", text)
        return text

    def _get_text_mask(self, text_ids_lengths: np.ndarray) -> np.ndarray:
        text_mask = length_to_mask(text_ids_lengths)
        return text_mask

    def _text_to_unicode_values(self, text: str) -> np.ndarray:
        unicode_values = np.array(
            [ord(char) for char in text], dtype=np.uint16
        )  # 2 bytes
        return unicode_values

    def __call__(self, text_list: list[str]) -> tuple[np.ndarray, np.ndarray]:
        text_list = [self._preprocess_text(t) for t in text_list]
        text_ids_lengths = np.array([len(text) for text in text_list], dtype=np.int64)
        text_ids = np.zeros((len(text_list), text_ids_lengths.max()), dtype=np.int64)
        for i, text in enumerate(text_list):
            unicode_vals = self._text_to_unicode_values(text)
            text_ids[i, : len(unicode_vals)] = np.array(
                [self.indexer[val] for val in unicode_vals], dtype=np.int64
            )
        text_mask = self._get_text_mask(text_ids_lengths)
        return text_ids, text_mask


class Style:
    def __init__(self, style_ttl_onnx: np.ndarray, style_dp_onnx: np.ndarray):
        self.ttl = style_ttl_onnx
        self.dp = style_dp_onnx


class TextToSpeech:
    def __init__(
        self,
        cfgs: dict,
        text_processor: UnicodeProcessor,
        dp_ort: ort.InferenceSession,
        text_enc_ort: ort.InferenceSession,
        vector_est_ort: ort.InferenceSession,
        vocoder_ort: ort.InferenceSession,
    ):
        self.cfgs = cfgs
        self.text_processor = text_processor
        self.dp_ort = dp_ort
        self.text_enc_ort = text_enc_ort
        self.vector_est_ort = vector_est_ort
        self.vocoder_ort = vocoder_ort
        self.sample_rate = cfgs["ae"]["sample_rate"]
        self.base_chunk_size = cfgs["ae"]["base_chunk_size"]
        self.chunk_compress_factor = cfgs["ttl"]["chunk_compress_factor"]
        self.ldim = cfgs["ttl"]["latent_dim"]

    def sample_noisy_latent(
        self, duration: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        bsz = len(duration)
        wav_len_max = duration.max() * self.sample_rate
        wav_lengths = (duration * self.sample_rate).astype(np.int64)
        chunk_size = self.base_chunk_size * self.chunk_compress_factor
        latent_len = ((wav_len_max + chunk_size - 1) / chunk_size).astype(np.int32)
        latent_dim = self.ldim * self.chunk_compress_factor
        noisy_latent = np.random.randn(bsz, latent_dim, latent_len).astype(np.float32)
        latent_mask = get_latent_mask(
            wav_lengths, self.base_chunk_size, self.chunk_compress_factor
        )
        noisy_latent = noisy_latent * latent_mask
        return noisy_latent, latent_mask

    def _infer(
        self, text_list: list[str], style: Style, total_step: int, speed: float = 1.05
    ) -> tuple[np.ndarray, np.ndarray]:
        assert (
            len(text_list) == style.ttl.shape[0]
        ), "Number of texts must match number of style vectors"
        bsz = len(text_list)
        text_ids, text_mask = self.text_processor(text_list)
        dur_onnx, *_ = self.dp_ort.run(
            None, {"text_ids": text_ids, "style_dp": style.dp, "text_mask": text_mask}
        )
        dur_onnx = dur_onnx / speed
        text_emb_onnx, *_ = self.text_enc_ort.run(
            None,
            {"text_ids": text_ids, "style_ttl": style.ttl, "text_mask": text_mask},
        )
        # dur_onnx: [bsz], predicted duration in seconds
        xt, latent_mask = self.sample_noisy_latent(dur_onnx)
        total_step_np = np.array([total_step] * bsz, dtype=np.float32)
        # Iteratively refine the noisy latent with the vector estimator
        for step in range(total_step):
            current_step = np.array([step] * bsz, dtype=np.float32)
            xt, *_ = self.vector_est_ort.run(
                None,
                {
                    "noisy_latent": xt,
                    "text_emb": text_emb_onnx,
                    "style_ttl": style.ttl,
                    "text_mask": text_mask,
                    "latent_mask": latent_mask,
                    "current_step": current_step,
                    "total_step": total_step_np,
                },
            )
        wav, *_ = self.vocoder_ort.run(None, {"latent": xt})
        return wav, dur_onnx

    def __call__(
        self,
        text: str,
        style: Style,
        total_step: int,
        speed: float = 1.05,
        silence_duration: float = 0.3,
    ) -> tuple[np.ndarray, np.ndarray]:
        assert (
            style.ttl.shape[0] == 1
        ), "Single speaker text to speech only supports single style"
        text_list = chunk_text(text)
        wav_cat = None
        dur_cat = None
        for chunk in text_list:
            wav, dur_onnx = self._infer([chunk], style, total_step, speed)
            if wav_cat is None:
                wav_cat = wav
                dur_cat = dur_onnx
            else:
                silence = np.zeros(
                    (1, int(silence_duration * self.sample_rate)), dtype=np.float32
                )
                wav_cat = np.concatenate([wav_cat, silence, wav], axis=1)
                dur_cat += dur_onnx + silence_duration
        return wav_cat, dur_cat

    def batch(
        self, text_list: list[str], style: Style, total_step: int, speed: float = 1.05
    ) -> tuple[np.ndarray, np.ndarray]:
        return self._infer(text_list, style, total_step, speed)


def length_to_mask(lengths: np.ndarray, max_len: Optional[int] = None) -> np.ndarray:
    """
    Convert lengths to binary mask.

    Args:
        lengths: (B,)
        max_len: int

    Returns:
        mask: (B, 1, max_len)
    """
    max_len = max_len or lengths.max()
    ids = np.arange(0, max_len)
    mask = (ids < np.expand_dims(lengths, axis=1)).astype(np.float32)
    return mask.reshape(-1, 1, max_len)


def get_latent_mask(
    wav_lengths: np.ndarray, base_chunk_size: int, chunk_compress_factor: int
) -> np.ndarray:
    latent_size = base_chunk_size * chunk_compress_factor
    latent_lengths = (wav_lengths + latent_size - 1) // latent_size
    latent_mask = length_to_mask(latent_lengths)
    return latent_mask


def load_onnx(
    onnx_path: str, opts: ort.SessionOptions, providers: list[str]
) -> ort.InferenceSession:
    return ort.InferenceSession(onnx_path, sess_options=opts, providers=providers)


def load_onnx_all(
    onnx_dir: str, opts: ort.SessionOptions, providers: list[str]
) -> tuple[
    ort.InferenceSession,
    ort.InferenceSession,
    ort.InferenceSession,
    ort.InferenceSession,
]:
    dp_onnx_path = os.path.join(onnx_dir, "duration_predictor.onnx")
    text_enc_onnx_path = os.path.join(onnx_dir, "text_encoder.onnx")
    vector_est_onnx_path = os.path.join(onnx_dir, "vector_estimator.onnx")
    vocoder_onnx_path = os.path.join(onnx_dir, "vocoder.onnx")
    dp_ort = load_onnx(dp_onnx_path, opts, providers)
    text_enc_ort = load_onnx(text_enc_onnx_path, opts, providers)
    vector_est_ort = load_onnx(vector_est_onnx_path, opts, providers)
    vocoder_ort = load_onnx(vocoder_onnx_path, opts, providers)
    return dp_ort, text_enc_ort, vector_est_ort, vocoder_ort


def load_cfgs(onnx_dir: str) -> dict:
    cfg_path = os.path.join(onnx_dir, "tts.json")
    with open(cfg_path, "r") as f:
        cfgs = json.load(f)
    return cfgs


def load_text_processor(onnx_dir: str) -> UnicodeProcessor:
    unicode_indexer_path = os.path.join(onnx_dir, "unicode_indexer.json")
    text_processor = UnicodeProcessor(unicode_indexer_path)
    return text_processor


def load_text_to_speech(onnx_dir: str, use_gpu: bool = False) -> TextToSpeech:
    opts = ort.SessionOptions()
    if use_gpu:
        raise NotImplementedError("GPU mode is not fully tested")
    else:
        providers = ["CPUExecutionProvider"]
        print("Using CPU for inference")
    cfgs = load_cfgs(onnx_dir)
    dp_ort, text_enc_ort, vector_est_ort, vocoder_ort = load_onnx_all(
        onnx_dir, opts, providers
    )
    text_processor = load_text_processor(onnx_dir)
    return TextToSpeech(
        cfgs,
        text_processor,
        dp_ort,
        text_enc_ort,
        vector_est_ort,
        vocoder_ort,
    )


def load_voice_style(voice_style_paths: list[str], verbose: bool = False) -> Style:
    bsz = len(voice_style_paths)

    # Read first file to get dimensions
    with open(voice_style_paths[0], "r") as f:
        first_style = json.load(f)
    ttl_dims = first_style["style_ttl"]["dims"]
    dp_dims = first_style["style_dp"]["dims"]

    # Pre-allocate arrays with full batch size
    ttl_style = np.zeros([bsz, ttl_dims[1], ttl_dims[2]], dtype=np.float32)
    dp_style = np.zeros([bsz, dp_dims[1], dp_dims[2]], dtype=np.float32)

    # Fill in the data
    for i, voice_style_path in enumerate(voice_style_paths):
        with open(voice_style_path, "r") as f:
            voice_style = json.load(f)
        ttl_data = np.array(
            voice_style["style_ttl"]["data"], dtype=np.float32
        ).flatten()
        ttl_style[i] = ttl_data.reshape(ttl_dims[1], ttl_dims[2])
        dp_data = np.array(voice_style["style_dp"]["data"], dtype=np.float32).flatten()
        dp_style[i] = dp_data.reshape(dp_dims[1], dp_dims[2])

    if verbose:
        print(f"Loaded {bsz} voice styles")
    return Style(ttl_style, dp_style)


@contextmanager
def timer(name: str):
    start = time.time()
    print(f"{name}...")
    yield
    print(f" -> {name} completed in {time.time() - start:.2f} sec")


def sanitize_filename(text: str, max_len: int) -> str:
    """Sanitize filename by replacing non-alphanumeric characters with underscores"""
    import re

    prefix = text[:max_len]
    return re.sub(r"[^a-zA-Z0-9]", "_", prefix)


def chunk_text(text: str, max_len: int = 300) -> list[str]:
    """
    Split text into chunks by paragraphs and sentences.

    Args:
        text: Input text to chunk
        max_len: Maximum length of each chunk (default: 300)

    Returns:
        List of text chunks
    """
    import re

    # Split by paragraph (two or more newlines)
    paragraphs = [p.strip() for p in re.split(r"\n\s*\n+", text.strip()) if p.strip()]

    chunks = []
    for paragraph in paragraphs:
        paragraph = paragraph.strip()
        if not paragraph:
            continue

        # Split by sentence boundaries (period, question mark, exclamation mark
        # followed by whitespace), but exclude common abbreviations like Mr.,
        # Mrs., Dr., etc. and single capital letters like F. Python's `re`
        # requires fixed-width lookbehinds, so each abbreviation gets its own
        # lookbehind; the abbreviation set here is representative, not exhaustive.
        pattern = (
            r"(?<!\bMr\.)(?<!\bMrs\.)(?<!\bMs\.)(?<!\bDr\.)(?<!\bProf\.)"
            r"(?<!\b[A-Z]\.)(?<=[.!?])\s+"
        )
        sentences = [s.strip() for s in re.split(pattern, paragraph) if s.strip()]

        # Greedily pack sentences into chunks of at most max_len characters;
        # a single sentence longer than max_len becomes its own chunk.
        current = ""
        for sentence in sentences:
            if len(sentence) > max_len:
                if current:
                    chunks.append(current)
                    current = ""
                chunks.append(sentence)
            elif current and len(current) + 1 + len(sentence) > max_len:
                chunks.append(current)
                current = sentence
            else:
                current = f"{current} {sentence}".strip()
        if current:
            chunks.append(current)

    return chunks
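

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only). Assumes the four ONNX graphs,
# tts.json, unicode_indexer.json, and a voice style JSON are hosted in a
# Hugging Face repo; the repo id, style filename, and step count below are
# hypothetical placeholders, not part of this module's contract.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    onnx_dir = snapshot_download("your-org/your-tts-onnx")  # hypothetical repo id
    tts = load_text_to_speech(onnx_dir, use_gpu=False)
    style = load_voice_style(
        [os.path.join(onnx_dir, "voice_style.json")],  # hypothetical style file
        verbose=True,
    )

    text = "Hello world. This is a quick smoke test of the ONNX TTS pipeline."
    with timer("Synthesizing"):
        wav, dur = tts(text, style, total_step=16)  # placeholder step count

    # wav: (1, num_samples); dur: (1,) total duration in seconds
    out_path = f"{sanitize_filename(text, 32)}.wav"
    sf.write(out_path, wav.squeeze(0), tts.sample_rate)
    print(f"Wrote {out_path} ({dur.item():.2f} sec of audio)")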