import os

import torch
import torch.nn.functional as F
import torchaudio
from safetensors.torch import load_file
from torch.nn.utils.rnn import pad_sequence
from torchdiffeq import odeint

from duration_predictor import SpeechLengthPredictor
from f5_tts.infer.utils_infer import (chunk_text, convert_char_to_pinyin,
                                      hop_length, load_vocoder,
                                      preprocess_ref_audio_text, speed,
                                      target_rms, target_sample_rate,
                                      transcribe)

# Import F5-TTS modules
from f5_tts.model import CFM, DiT, UNetT
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import (default, exists, get_tokenizer, lens_to_mask,
                                list_str_to_idx, list_str_to_tensor,
                                mask_from_frac_lengths)

# Import custom modules
from unimodel import UniModel


class DMOInference:
    """F5-TTS inference wrapper class for easy text-to-speech generation."""

    def __init__(
        self,
        student_checkpoint_path="",
        duration_predictor_path="",
        device="cuda",
        model_type="F5TTS_Base",  # "F5TTS_Base" or "E2TTS_Base"
        tokenizer="pinyin",
        dataset_name="Emilia_ZH_EN",
    ):
        """
        Initialize the F5-TTS inference model.

        Args:
            student_checkpoint_path: Path to the student model checkpoint
            duration_predictor_path: Path to the duration predictor checkpoint
            device: Device to run inference on
            model_type: Model architecture type
            tokenizer: Tokenizer type ("pinyin", "char", or "custom")
            dataset_name: Dataset name for the tokenizer
        """
        self.device = device
        self.model_type = model_type
        self.tokenizer = tokenizer
        self.dataset_name = dataset_name

        # Model parameters
        self.target_sample_rate = 24000
        self.n_mel_channels = 100
        self.hop_length = 256
        self.real_guidance_scale = 2
        self.fake_guidance_scale = 0
        self.gen_cls_loss = False
        self.num_student_step = 4

        # Initialize components
        self._setup_tokenizer()
        self._setup_models(student_checkpoint_path)
        self._setup_mel_spec()
        self._setup_vocoder()
        self._setup_duration_predictor(duration_predictor_path)

    def _setup_tokenizer(self):
        """Set up the tokenizer and vocabulary."""
        if self.tokenizer == "custom":
            # A custom tokenizer needs an explicit vocab path set on the instance.
            tokenizer_path = getattr(self, "tokenizer_path", None)
            if tokenizer_path is None:
                raise ValueError(
                    "tokenizer='custom' requires `self.tokenizer_path` to be set"
                )
        else:
            tokenizer_path = self.dataset_name
        self.vocab_char_map, self.vocab_size = get_tokenizer(
            tokenizer_path, self.tokenizer
        )

    def _setup_models(self, student_checkpoint_path):
        """Initialize the teacher and student models."""
        # Model configuration
        if self.model_type == "F5TTS_Base":
            model_cls = DiT
            model_cfg = dict(
                dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
            )
        elif self.model_type == "E2TTS_Base":
            model_cls = UNetT
            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
        else:
            raise ValueError(f"Unknown model type: {self.model_type}")

        # Initialize UniModel (student)
        self.model = UniModel(
            model_cls(
                **model_cfg,
                text_num_embeds=self.vocab_size,
                mel_dim=self.n_mel_channels,
                second_time=self.num_student_step > 1,
            ),
            checkpoint_path="",
            vocab_char_map=self.vocab_char_map,
            frac_lengths_mask=(0.5, 0.9),
            real_guidance_scale=self.real_guidance_scale,
            fake_guidance_scale=self.fake_guidance_scale,
            gen_cls_loss=self.gen_cls_loss,
            sway_coeff=0,
        )

        # Load student checkpoint
        checkpoint = torch.load(student_checkpoint_path, map_location="cpu")
        self.model.load_state_dict(checkpoint["model_state_dict"], strict=False)

        # Set up generator and teacher
        self.generator = self.model.feedforward_model.to(self.device)
        self.teacher = self.model.guidance_model.real_unet.to(self.device)
        self.scale = checkpoint["scale"]

    def _setup_mel_spec(self):
        """Initialize the mel spectrogram module."""
        mel_spec_kwargs = dict(
            target_sample_rate=self.target_sample_rate,
            n_mel_channels=self.n_mel_channels,
            hop_length=self.hop_length,
        )
        self.mel_spec = MelSpec(**mel_spec_kwargs)
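    # Note: the student checkpoint loaded in `_setup_models` is expected to be a
    # torch.save'd dict with at least the two keys read above, e.g. (a sketch;
    # extra keys are simply ignored):
    #     torch.save({"model_state_dict": model.state_dict(), "scale": 1.0}, path)
    # "scale" is the mel normalization factor divided out in `_prepare_inputs`
    # and multiplied back in before vocoding.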
    def _setup_vocoder(self):
        """Initialize the vocoder."""
        self.vocos = load_vocoder(is_local=False, local_path="")
        self.vocos = self.vocos.to(self.device)

    def _setup_duration_predictor(self, checkpoint_path):
        """Initialize the duration predictor."""
        self.wav2mel = MelSpec(
            target_sample_rate=24000,
            n_mel_channels=100,
            hop_length=256,
            win_length=1024,
            n_fft=1024,
            mel_spec_type="vocos",
        ).to(self.device)

        self.SLP = SpeechLengthPredictor(
            vocab_size=2545,
            n_mel=100,
            hidden_dim=512,
            n_text_layer=4,
            n_cross_layer=4,
            n_head=8,
            output_dim=301,
        ).to(self.device)
        self.SLP.load_state_dict(
            torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
        )
        self.SLP.eval()

    def predict_duration(
        self, pmt_wav_path, tar_text, pmt_text, dp_softmax_range=0.7, temperature=0
    ):
        """
        Predict the duration of the target text based on the prompt audio.

        Args:
            pmt_wav_path: Path to prompt audio
            tar_text: Target text to generate
            pmt_text: Prompt text
            dp_softmax_range: Softmax truncation window around the rate-based
                duration estimate; logits outside it are masked out
            temperature: Temperature for softmax sampling (0 means use argmax)

        Returns:
            Estimated duration in mel frames
        """
        pmt_wav, sr = torchaudio.load(pmt_wav_path)
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
            pmt_wav = resampler(pmt_wav)
        if pmt_wav.size(0) > 1:
            pmt_wav = pmt_wav[0].unsqueeze(0)
        pmt_wav = pmt_wav.to(self.device)
        pmt_mel = self.wav2mel(pmt_wav).permute(0, 2, 1)

        tar_tokens = self._convert_to_pinyin(list(tar_text))
        pmt_tokens = self._convert_to_pinyin(list(pmt_text))

        # Rate-based duration estimate: scale the prompt's frames-per-token rate
        # to the target token count, quantized into 10-frame bins.
        ref_text_len = len(pmt_tokens)
        gen_text_len = len(tar_tokens)
        ref_audio_len = pmt_mel.size(1)
        duration = int(ref_audio_len / ref_text_len * gen_text_len / speed)
        duration = duration // 10
        min_duration = max(int(duration * dp_softmax_range), 0)
        max_duration = min(int(duration * (1 + dp_softmax_range)), 301)

        all_tokens = pmt_tokens + [" "] + tar_tokens
        text_ids = list_str_to_idx([all_tokens], self.vocab_char_map).to(self.device)
        text_ids = text_ids.masked_fill(text_ids == -1, self.vocab_size)

        with torch.no_grad():
            predictions = self.SLP(text_ids=text_ids, mel=pmt_mel)

        # Keep the last-position logits and mask bins outside the allowed window.
        predictions = predictions[:, -1, :]
        predictions[:, :min_duration] = float("-inf")
        predictions[:, max_duration:] = float("-inf")

        if temperature == 0:
            est_label = predictions.argmax(-1)[..., -1].item() * 10
        else:
            probs = torch.softmax(predictions / temperature, dim=-1)
            sampled_idx = torch.multinomial(probs.squeeze(0), num_samples=1)
            est_label = sampled_idx.item() * 10

        return est_label

    def _convert_to_pinyin(self, char_list):
        """Convert a character list to pinyin tokens, stripping leading spaces."""
        result = []
        for x in convert_char_to_pinyin(char_list):
            result = result + x
        while result[0] == " " and len(result) > 1:
            result = result[1:]
        return result
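    # Example (a sketch; paths and texts are placeholders):
    #     tts = DMOInference("student.pt", "duration_predictor.pt", device="cuda")
    #     n_frames = tts.predict_duration("prompt.wav", "Hello world.", "Prompt text.")
    #     seconds = n_frames * tts.hop_length / tts.target_sample_rate  # 256 / 24000 s per frame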
    def generate(
        self,
        gen_text,
        audio_path,
        prompt_text=None,
        teacher_steps=16,
        teacher_stopping_time=0.07,
        student_start_step=1,
        duration=None,
        dp_softmax_range=0.7,
        temperature=0,
        eta=1.0,
        cfg_strength=2.0,
        sway_coefficient=-1.0,
        verbose=False,
    ):
        """
        Generate speech from text using teacher-student distillation.

        Args:
            gen_text: Text to generate
            audio_path: Path to prompt audio
            prompt_text: Prompt text (if None, will use ASR)
            teacher_steps: Number of teacher guidance steps
            teacher_stopping_time: When to stop teacher sampling
            student_start_step: When to start student sampling
            duration: Target duration in frames (if None, will predict)
            dp_softmax_range: Duration predictor softmax window around the
                rate-based duration estimate
            temperature: Temperature for duration predictor sampling (0 means argmax)
            eta: Stochasticity control (0 = DDIM-like, 1 = DDPM-like)
            cfg_strength: Classifier-free guidance strength
            sway_coefficient: Sway sampling coefficient
            verbose: Print sampling steps

        Returns:
            Generated audio waveform as a NumPy array
        """
        if prompt_text is None:
            prompt_text = transcribe(audio_path)

        # Predict duration if not provided
        if duration is None:
            duration = self.predict_duration(
                audio_path, gen_text, prompt_text, dp_softmax_range, temperature
            )

        # Preprocess audio and text
        ref_audio, ref_text = preprocess_ref_audio_text(audio_path, prompt_text)
        audio, sr = torchaudio.load(ref_audio)
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        # Normalize audio
        rms = torch.sqrt(torch.mean(torch.square(audio)))
        if rms < target_rms:
            audio = audio * target_rms / rms
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
            audio = resampler(audio)
        audio = audio.to(self.device)

        # Prepare text
        text_list = [ref_text + gen_text]
        final_text_list = convert_char_to_pinyin(text_list)

        # Total duration: prompt frames plus target frames (duration is always
        # set at this point, either user-provided or predicted above).
        ref_audio_len = audio.shape[-1] // self.hop_length
        duration = ref_audio_len + duration

        if verbose:
            print("audio:", audio.shape)
            print("text:", final_text_list)
            print("duration:", duration)
            print("eta (stochasticity):", eta)

        # Run inference
        with torch.inference_mode():
            cond, text, step_cond, cond_mask, max_duration, duration_tensor = (
                self._prepare_inputs(audio, final_text_list, duration)
            )

            # Teacher-student sampling
            if teacher_steps > 0 and student_start_step > 0:
                if verbose:
                    print(f"Start teacher sampling with hybrid DDIM/DDPM (eta={eta})...")
                x1 = self._teacher_sampling(
                    step_cond,
                    text,
                    cond_mask,
                    max_duration,
                    duration_tensor,
                    teacher_steps,
                    teacher_stopping_time,
                    eta,
                    cfg_strength,
                    verbose,
                    sway_coefficient,
                )
            else:
                x1 = step_cond

            if verbose:
                print("Start student sampling...")

            # Student sampling
            x1 = self._student_sampling(
                x1, cond, text, student_start_step, verbose, sway_coefficient
            )

            # Decode to audio, dropping the prompt frames
            mel = x1.permute(0, 2, 1) * self.scale
            generated_wave = self.vocos.decode(mel[..., cond_mask.sum() :])

        return generated_wave.cpu().numpy().squeeze()
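    # Example (a sketch; paths are placeholders and soundfile is an assumed
    # dependency for writing audio):
    #     import soundfile as sf
    #     wav = tts.generate("Text to synthesize.", "prompt.wav")
    #     sf.write("output.wav", wav, tts.target_sample_rate)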
    def generate_teacher_only(
        self,
        gen_text,
        audio_path,
        prompt_text=None,
        teacher_steps=32,
        duration=None,
        eta=1.0,
        cfg_strength=2.0,
        sway_coefficient=-1.0,
    ):
        """
        Generate speech using the teacher model only (no student distillation).

        Args:
            gen_text: Text to generate
            audio_path: Path to prompt audio
            prompt_text: Prompt text (if None, will use ASR)
            teacher_steps: Number of sampling steps
            duration: Target duration in frames (if None, will predict)
            eta: Stochasticity control (0 = DDIM-like, 1 = DDPM-like)
            cfg_strength: Classifier-free guidance strength
            sway_coefficient: Sway sampling coefficient

        Returns:
            Generated audio waveform as a NumPy array
        """
        if prompt_text is None:
            prompt_text = transcribe(audio_path)

        # Predict duration if not provided
        if duration is None:
            duration = self.predict_duration(audio_path, gen_text, prompt_text)

        # Preprocess audio and text
        ref_audio, ref_text = preprocess_ref_audio_text(audio_path, prompt_text)
        audio, sr = torchaudio.load(ref_audio)
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        # Normalize audio
        rms = torch.sqrt(torch.mean(torch.square(audio)))
        if rms < target_rms:
            audio = audio * target_rms / rms
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
            audio = resampler(audio)
        audio = audio.to(self.device)

        # Prepare text
        text_list = [ref_text + gen_text]
        final_text_list = convert_char_to_pinyin(text_list)

        # Total duration: prompt frames plus target frames (duration is always
        # set at this point, either user-provided or predicted above).
        ref_audio_len = audio.shape[-1] // self.hop_length
        duration = ref_audio_len + duration

        # Run inference
        with torch.inference_mode():
            cond, text, step_cond, cond_mask, max_duration, duration_tensor = (
                self._prepare_inputs(audio, final_text_list, duration)
            )

            # Teacher-only sampling; stopping_time=1.0 runs the full schedule
            x1 = self._teacher_sampling(
                step_cond,
                text,
                cond_mask,
                max_duration,
                duration_tensor,
                teacher_steps,
                1.0,
                eta,
                cfg_strength,
                False,
                sway_coefficient,
            )

            # Decode to audio, dropping the prompt frames
            mel = x1.permute(0, 2, 1) * self.scale
            generated_wave = self.vocos.decode(mel[..., cond_mask.sum() :])

        return generated_wave.cpu().numpy().squeeze()

    def _prepare_inputs(self, audio, text_list, duration):
        """Prepare conditioning, text ids, and duration tensors for generation."""
        lens = None
        max_duration_limit = 4096

        cond = audio
        text = text_list

        # Convert raw audio to a normalized mel spectrogram if needed
        if cond.ndim == 2:
            cond = self.mel_spec(cond)
            cond = cond.permute(0, 2, 1)
            assert cond.shape[-1] == 100
        cond = cond / self.scale

        batch, cond_seq_len, device = *cond.shape[:2], cond.device
        if not exists(lens):
            lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)

        # Process text
        if isinstance(text, list):
            if exists(self.vocab_char_map):
                text = list_str_to_idx(text, self.vocab_char_map).to(device)
            else:
                text = list_str_to_tensor(text).to(device)
            assert text.shape[0] == batch

        if exists(text):
            text_lens = (text != -1).sum(dim=-1)
            lens = torch.maximum(text_lens, lens)

        # Process duration
        cond_mask = lens_to_mask(lens)
        if isinstance(duration, int):
            duration = torch.full((batch,), duration, device=device, dtype=torch.long)
        duration = torch.maximum(lens + 1, duration)
        duration = duration.clamp(max=max_duration_limit)
        max_duration = duration.amax()

        # Pad conditioning to the maximum duration
        cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
        cond_mask = F.pad(
            cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False
        )
        cond_mask = cond_mask.unsqueeze(-1)
        step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))

        return cond, text, step_cond, cond_mask, max_duration, duration
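    # Note on the hybrid DDIM/DDPM update in `_teacher_sampling` below: with
    # alpha_t = 1 - t and sigma_t = t, the injected noise is scaled by
    #     sqrt(eta * sigma_t**2 / (alpha_t**2 + sigma_t**2))
    # so eta = 0 recovers a deterministic (DDIM-like) Euler solve, while
    # eta = 1 injects the full (DDPM-like) amount of per-step stochasticity.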
    def _teacher_sampling(
        self,
        step_cond,
        text,
        cond_mask,
        max_duration,
        duration,
        teacher_steps,
        teacher_stopping_time,
        eta,
        cfg_strength,
        verbose,
        sway_sampling_coef=-1,
    ):
        """Perform teacher model sampling."""
        device = step_cond.device

        # Pre-generate the noise sequence for stochastic sampling
        noise_seq = None
        if eta > 0:
            noise_seq = [
                torch.randn(1, max_duration, 100, device=device)
                for _ in range(teacher_steps)
            ]

        def fn(t, x):
            with torch.inference_mode():
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    if verbose:
                        print(f"current t: {t}")

                    step_frac = 1.0 - t.item()
                    step_idx = (
                        min(int(step_frac * len(noise_seq)), len(noise_seq) - 1)
                        if noise_seq
                        else 0
                    )

                    # Predict flow
                    pred = self.teacher(
                        x=x,
                        cond=step_cond,
                        text=text,
                        time=t,
                        mask=None,
                        drop_audio_cond=False,
                        drop_text=False,
                    )
                    # Classifier-free guidance: extrapolate away from the
                    # unconditional prediction
                    if cfg_strength > 1e-5:
                        null_pred = self.teacher(
                            x=x,
                            cond=step_cond,
                            text=text,
                            time=t,
                            mask=None,
                            drop_audio_cond=True,
                            drop_text=True,
                        )
                        pred = pred + (pred - null_pred) * cfg_strength

                    # Add stochasticity if eta > 0
                    if eta > 0 and noise_seq is not None:
                        alpha_t = 1.0 - t.item()
                        sigma_t = t.item()
                        noise_scale = torch.sqrt(
                            torch.tensor(
                                (sigma_t**2) / (alpha_t**2 + sigma_t**2) * eta,
                                device=device,
                            )
                        )
                        return pred + noise_scale * noise_seq[step_idx]
                    return pred

        # Initialize noise
        y0 = []
        for dur in duration:
            y0.append(torch.randn(dur, 100, device=device, dtype=step_cond.dtype))
        y0 = pad_sequence(y0, padding_value=0, batch_first=True)

        # Set up time steps, warped by sway sampling
        t = torch.linspace(
            0, 1, teacher_steps + 1, device=device, dtype=step_cond.dtype
        )
        if sway_sampling_coef is not None:
            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

        # Truncate the schedule at the first point past the stopping time; if
        # nothing exceeds it (teacher_stopping_time >= 1), keep the full schedule.
        exceed = t > teacher_stopping_time
        if exceed.any():
            t = t[: exceed.float().argmax() + 1]

        # Solve the ODE
        trajectory = odeint(fn, y0, t, method="euler")

        if teacher_stopping_time < 1.0:
            # When stopping early, take one final Euler jump to t = 1
            pred = fn(t[-1], trajectory[-1])
            return trajectory[-1] + (1 - t[-1]) * pred
        return trajectory[-1]

    def _student_sampling(
        self, x1, cond, text, student_start_step, verbose, sway_coeff=-1
    ):
        """Perform student model sampling."""
        # Fixed four-step schedule, warped by sway sampling
        steps = torch.tensor([0, 0.25, 0.5, 0.75])
        steps = steps + sway_coeff * (torch.cos(torch.pi / 2 * steps) - 1 + steps)
        steps = steps[student_start_step:]

        for step in steps:
            time = step.unsqueeze(0).to(x1.device)
            x0 = torch.randn_like(x1)
            t = time.unsqueeze(-1).unsqueeze(-1)
            # Re-noise the current estimate back to time t along the flow path
            phi = (1 - t) * x0 + t * x1

            if verbose:
                print(f"current step: {step}")

            with torch.no_grad():
                pred = self.generator(
                    x=phi,
                    cond=cond,
                    text=text,
                    time=time,
                    drop_audio_cond=False,
                    drop_text=False,
                )
                # One-step jump to the predicted mel spectrogram
                x1 = phi + (1 - t) * pred

        return x1
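
if __name__ == "__main__":
    # Minimal end-to-end usage sketch. The checkpoint and audio paths below are
    # placeholders (not files shipped with this module), and soundfile is an
    # assumed dependency for writing the output waveform.
    import soundfile as sf

    tts = DMOInference(
        student_checkpoint_path="ckpts/student.pt",
        duration_predictor_path="ckpts/duration_predictor.pt",
        device="cuda",
    )
    # With prompt_text=None, the prompt transcript is obtained via ASR.
    wav = tts.generate(
        gen_text="Hello, this is a test of the distilled TTS model.",
        audio_path="prompts/example.wav",
        verbose=True,
    )
    sf.write("generated.wav", wav, tts.target_sample_rate)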